test: unit test cases for rag.cleaner, rag.data_post_processor and rag.datasource (#32521)

This commit is contained in:
Rajat Agarwal
2026-03-24 23:49:15 +05:30
committed by GitHub
parent 36cc1bf025
commit 6f137fdb00
50 changed files with 13766 additions and 47 deletions

View File

@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector):
)
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
docs.append(doc)
return docs

View File

@ -211,3 +211,16 @@ class TestCleanProcessor:
text = "[Text with (parens) and symbols](https://example.com)"
expected = "[Text with (parens) and symbols](https://example.com)"
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_remove_urls_emails_preserves_markdown_image_links(self):
    """Remove plain URLs and emails while preserving markdown image links."""
    # Only the remove_urls_emails pre-processing rule is enabled.
    process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
    text = "Email test@example.com and remove https://remove.com but keep ![diagram](https://example.com/image.png)"
    result = CleanProcessor.clean(text, process_rule)
    # Bare email and URL are stripped; the markdown image link survives intact.
    assert result == "Email and remove but keep ![diagram](https://example.com/image.png)"
def test_filter_string_returns_input_text(self):
    """filter_string is a passthrough: input text comes back unchanged."""
    assert CleanProcessor().filter_string("raw text") == "raw text"

View File

@ -0,0 +1,249 @@
from unittest.mock import MagicMock, patch
import pytest
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
def _doc(content: str) -> Document:
    """Shorthand for a Document carrying only the given page content."""
    document = Document(page_content=content)
    return document
class TestDataPostProcessor:
    """Unit tests covering DataPostProcessor construction and invocation paths."""

    def test_init_sets_rerank_and_reorder_runners(self):
        """__init__ stores the runners built by the two factory helpers."""
        rerank_runner = object()
        reorder_runner = object()
        with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock:
            with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock:
                processor = DataPostProcessor(
                    tenant_id="tenant-1",
                    reranking_mode=RerankMode.WEIGHTED_SCORE,
                    reranking_model={"config": "value"},
                    weights={"weight": "value"},
                    reorder_enabled=True,
                )
                assert processor.rerank_runner is rerank_runner
                assert processor.reorder_runner is reorder_runner
                # Constructor arguments are forwarded to the factory helpers.
                rerank_mock.assert_called_once_with(
                    RerankMode.WEIGHTED_SCORE,
                    "tenant-1",
                    {"config": "value"},
                    {"weight": "value"},
                )
                reorder_mock.assert_called_once_with(True)

    def test_invoke_applies_rerank_then_reorder(self):
        """invoke() feeds documents through the rerank runner first, then reorder."""
        original_documents = [_doc("doc-a")]
        reranked_documents = [_doc("doc-b")]
        reordered_documents = [_doc("doc-c")]
        # Bypass __init__ so runners can be injected directly.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        processor.rerank_runner = MagicMock()
        processor.rerank_runner.run.return_value = reranked_documents
        processor.reorder_runner = MagicMock()
        processor.reorder_runner.run.return_value = reordered_documents
        result = processor.invoke(
            query="how to test",
            documents=original_documents,
            score_threshold=0.3,
            top_n=2,
            user="user-1",
            query_type=QueryType.IMAGE_QUERY,
        )
        processor.rerank_runner.run.assert_called_once_with(
            "how to test",
            original_documents,
            0.3,
            2,
            "user-1",
            QueryType.IMAGE_QUERY,
        )
        # Reorder receives the reranked list, not the originals.
        processor.reorder_runner.run.assert_called_once_with(reranked_documents)
        assert result == reordered_documents

    def test_invoke_returns_original_documents_when_no_runner_is_configured(self):
        """With neither runner configured, invoke() returns the input list."""
        documents = [_doc("doc-a"), _doc("doc-b")]
        processor = DataPostProcessor.__new__(DataPostProcessor)
        processor.rerank_runner = None
        processor.reorder_runner = None
        assert processor.invoke(query="query", documents=documents) == documents

    def test_get_rerank_runner_for_weighted_score(self):
        """WEIGHTED_SCORE mode builds a runner from the weights dict via the factory."""
        weights_config = {
            "vector_setting": {
                "vector_weight": 0.7,
                "embedding_provider_name": "provider-x",
                "embedding_model_name": "embedding-y",
            },
            "keyword_setting": {"keyword_weight": 0.3},
        }
        expected_runner = object()
        processor = DataPostProcessor.__new__(DataPostProcessor)
        with patch(
            "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner",
            return_value=expected_runner,
        ) as factory_mock:
            result = processor._get_rerank_runner(
                reranking_mode=RerankMode.WEIGHTED_SCORE,
                tenant_id="tenant-1",
                reranking_model=None,
                weights=weights_config,
            )
            assert result is expected_runner
            # The raw dict is converted into a typed weights object.
            kwargs = factory_mock.call_args.kwargs
            assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE
            assert kwargs["tenant_id"] == "tenant-1"
            assert kwargs["weights"].vector_setting.vector_weight == 0.7
            assert kwargs["weights"].vector_setting.embedding_provider_name == "provider-x"
            assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y"
            assert kwargs["weights"].keyword_setting.keyword_weight == 0.3

    def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self):
        """If no model instance can be resolved, no runner is created."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        reranking_model = {
            "reranking_provider_name": "provider-x",
            "reranking_model_name": "model-y",
        }
        with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock:
            with patch(
                "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner"
            ) as factory_mock:
                result = processor._get_rerank_runner(
                    reranking_mode=RerankMode.RERANKING_MODEL,
                    tenant_id="tenant-1",
                    reranking_model=reranking_model,
                    weights=None,
                )
                assert result is None
                model_mock.assert_called_once_with("tenant-1", reranking_model)
                factory_mock.assert_not_called()

    def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self):
        """A resolved model instance is handed to the factory."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        model_instance = object()
        expected_runner = object()
        with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance):
            with patch(
                "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner",
                return_value=expected_runner,
            ) as factory_mock:
                result = processor._get_rerank_runner(
                    reranking_mode=RerankMode.RERANKING_MODEL,
                    tenant_id="tenant-1",
                    reranking_model={
                        "reranking_provider_name": "provider-x",
                        "reranking_model_name": "model-y",
                    },
                    weights=None,
                )
                assert result is expected_runner
                factory_mock.assert_called_once_with(
                    runner_type=RerankMode.RERANKING_MODEL,
                    rerank_model_instance=model_instance,
                )

    def test_get_rerank_runner_returns_none_for_unsupported_mode(self):
        """Unknown modes, or WEIGHTED_SCORE without weights, yield no runner."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None
        assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None

    def test_get_reorder_runner_by_flag(self):
        """reorder_enabled=True yields a ReorderRunner; False yields None."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        assert isinstance(processor._get_reorder_runner(True), ReorderRunner)
        assert processor._get_reorder_runner(False) is None

    def test_get_rerank_model_instance_returns_none_when_config_is_missing(self):
        """No reranking_model config means no model instance."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        assert processor._get_rerank_model_instance("tenant-1", None) is None

    def test_get_rerank_model_instance_raises_key_error_for_incomplete_config(self):
        """A config missing reranking_model_name raises KeyError before any lookup."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            with pytest.raises(KeyError, match="reranking_model_name"):
                processor._get_rerank_model_instance(
                    tenant_id="tenant-1",
                    reranking_model={"reranking_provider_name": "provider-x"},
                )
            manager_instance.get_model_instance.assert_not_called()

    def test_get_rerank_model_instance_success(self):
        """A complete config resolves a RERANK model instance via ModelManager."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        model_instance = object()
        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            manager_instance.get_model_instance.return_value = model_instance
            result = processor._get_rerank_model_instance(
                tenant_id="tenant-1",
                reranking_model={
                    "reranking_provider_name": "provider-x",
                    "reranking_model_name": "reranker-1",
                },
            )
            assert result is model_instance
            manager_instance.get_model_instance.assert_called_once_with(
                tenant_id="tenant-1",
                provider="provider-x",
                model_type=ModelType.RERANK,
                model="reranker-1",
            )

    def test_get_rerank_model_instance_handles_authorization_error(self):
        """InvokeAuthorizationError from the model manager is swallowed; None returned."""
        processor = DataPostProcessor.__new__(DataPostProcessor)
        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized")
            result = processor._get_rerank_model_instance(
                tenant_id="tenant-1",
                reranking_model={
                    "reranking_provider_name": "provider-x",
                    "reranking_model_name": "reranker-1",
                },
            )
            assert result is None
class TestReorderRunner:
    """Unit tests for ReorderRunner.run document reordering."""

    def test_run_reorders_even_sized_document_list(self):
        """Even-length input: even indices ascending, then odd indices descending."""
        documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4"), _doc("5")]
        reordered = ReorderRunner().run(documents)
        assert [document.page_content for document in reordered] == ["0", "2", "4", "5", "3", "1"]

    def test_run_handles_odd_sized_and_empty_document_lists(self):
        """Odd-length lists follow the same pattern; empty input yields empty output."""
        odd_documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4")]
        runner = ReorderRunner()
        odd_reordered = runner.run(odd_documents)
        assert [document.page_content for document in odd_reordered] == ["0", "2", "4", "3", "1"]
        assert runner.run([]) == []

View File

@ -0,0 +1,414 @@
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.rag.datasource.keyword.jieba.jieba as jieba_module
from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default
from core.rag.models.document import Document
class _DummyLock:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class _Field:
def __init__(self, name: str):
self._name = name
def __eq__(self, other):
return ("eq", self._name, other)
def in_(self, values):
return ("in", self._name, tuple(values))
class _FakeQuery:
def __init__(self):
self.where_calls: list[tuple] = []
def where(self, *conditions):
self.where_calls.append(conditions)
return self
class _FakeExecuteResult:
def __init__(self, segments: list[SimpleNamespace]):
self._segments = segments
def scalars(self):
return self
def all(self):
return self._segments
class _FakeSelect:
def __init__(self):
self.where_conditions: tuple | None = None
def where(self, *conditions):
self.where_conditions = conditions
return self
def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None):
return SimpleNamespace(
data_source_type=data_source_type,
keyword_table_dict=keyword_table_dict,
keyword_table="",
)
def _dataset(dataset_keyword_table=None, keyword_number=None):
return SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
keyword_number=keyword_number,
dataset_keyword_table=dataset_keyword_table,
)
@pytest.fixture
def patched_runtime(monkeypatch):
    """Replace the jieba module's db/storage/redis_client globals with test doubles.

    Returns a namespace exposing the session, storage, and lock mocks so tests
    can assert on how the Jieba keyword store used them.
    """
    session = MagicMock()
    db = SimpleNamespace(session=session)
    storage = MagicMock()
    # redis_client.lock(...) yields a no-op context manager.
    lock = MagicMock(return_value=_DummyLock())
    redis_client = SimpleNamespace(lock=lock)
    monkeypatch.setattr(jieba_module, "db", db)
    monkeypatch.setattr(jieba_module, "storage", storage)
    monkeypatch.setattr(jieba_module, "redis_client", redis_client)
    return SimpleNamespace(session=session, storage=storage, lock=lock)
def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime):
    """create() indexes docs with metadata, skips those without, and returns self."""
    dataset = _dataset(_dataset_keyword_table(), keyword_number=2)
    keyword = Jieba(dataset)
    handler = MagicMock()
    handler.extract_keywords.return_value = {"kw1", "kw2"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    result = keyword.create(
        [
            Document(page_content="alpha", metadata={"doc_id": "node-1"}),
            # No metadata: this entry should not be indexed.
            SimpleNamespace(page_content="ignored", metadata=None),
        ]
    )
    assert result is keyword
    # Only the document with metadata triggers a segment-keyword update.
    keyword._update_segment_keywords.assert_called_once()
    call_args = keyword._update_segment_keywords.call_args.args
    assert call_args[0] == "dataset-1"
    assert call_args[1] == "node-1"
    assert set(call_args[2]) == {"kw1", "kw2"}
    saved_table = keyword._save_dataset_keyword_table.call_args.args[0]
    assert saved_table["kw1"] == {"node-1"}
    assert saved_table["kw2"] == {"node-1"}
    # Indexing is guarded by a per-dataset redis lock.
    patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600)
def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime):
    """Empty per-text keyword lists fall back to the extractor; non-empty are used as-is."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"auto"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    texts = [
        Document(page_content="extract-this", metadata={"doc_id": "node-1"}),
        Document(page_content="use-manual", metadata={"doc_id": "node-2"}),
    ]
    # First entry has an empty keyword list, second a manual one.
    keyword.add_texts(texts, keywords_list=[[], ["manual"]])
    assert keyword._update_segment_keywords.call_count == 2
    first_call = keyword._update_segment_keywords.call_args_list[0].args
    second_call = keyword._update_segment_keywords.call_args_list[1].args
    assert set(first_call[2]) == {"auto"}
    assert second_call[2] == ["manual"]
    keyword._save_dataset_keyword_table.assert_called_once()
def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime):
    """Without keywords_list, every text is routed through the keyword extractor."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"from-extractor"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})])
    # keyword_number from the dataset is passed through as the extraction limit.
    handler.extract_keywords.assert_called_once_with("content", 1)
    assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"}
def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch):
    """text_exists is False without a table; otherwise checks node-id membership."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None))
    assert keyword.text_exists("node-1") is False
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    assert keyword.text_exists("node-2") is True
    assert keyword.text_exists("node-x") is False
def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime):
    """delete_by_ids removes the ids from the table and persists the result."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}}))
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.delete_by_ids(["node-1"])
    keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"])
    keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}})
def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime):
    """With no keyword table, delete_by_ids skips deletion and saves None."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None))
    monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.delete_by_ids(["node-1"])
    keyword._delete_ids_from_keyword_table.assert_not_called()
    keyword._save_dataset_keyword_table.assert_called_once_with(None)
def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime):
    """search() filters by document ids and builds Documents from matched segments."""

    class _FakeDocumentSegment:
        # Fake ORM columns recording comparisons (see _Field).
        dataset_id = _Field("dataset_id")
        index_node_id = _Field("index_node_id")
        document_id = _Field("document_id")

    keyword = Jieba(_dataset(_dataset_keyword_table()))
    query_stmt = _FakeQuery()
    patched_runtime.session.query.return_value = query_stmt
    patched_runtime.session.execute.return_value = _FakeExecuteResult(
        [
            SimpleNamespace(
                index_node_id="node-2",
                content="segment-content",
                index_node_hash="hash-2",
                document_id="doc-2",
                dataset_id="dataset-1",
            )
        ]
    )
    monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))
    documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])
    # Two where() calls: the base filter plus the document-ids filter.
    assert len(query_stmt.where_calls) == 2
    assert len(documents) == 1
    assert documents[0].page_content == "segment-content"
    assert documents[0].metadata["doc_id"] == "node-2"
    assert documents[0].metadata["doc_hash"] == "hash-2"
def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime):
    """delete() removes the DB row; non-database source types also delete the file."""
    db_keyword = _dataset_keyword_table(data_source_type="database")
    file_keyword = _dataset_keyword_table(data_source_type="object_storage")
    keyword_db = Jieba(_dataset(db_keyword))
    keyword_db.delete()
    # Database-backed tables never touch storage.
    patched_runtime.storage.delete.assert_not_called()
    keyword_file = Jieba(_dataset(file_keyword))
    keyword_file.delete()
    patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt")
    assert patched_runtime.session.delete.call_count == 2
    assert patched_runtime.session.commit.call_count == 2
def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime):
    """Database source type serializes the table onto the row and commits."""
    dataset_keyword_table = _dataset_keyword_table(data_source_type="database")
    keyword = Jieba(_dataset(dataset_keyword_table))
    keyword._save_dataset_keyword_table({"kw": {"node-1"}})
    # Serialized JSON carries the type marker and the dataset id.
    assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table
    assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table
    patched_runtime.session.commit.assert_called_once()
def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime):
    """Non-database source type deletes any existing file and saves bytes to storage."""
    dataset_keyword_table = _dataset_keyword_table(data_source_type="file")
    keyword = Jieba(_dataset(dataset_keyword_table))
    patched_runtime.storage.exists.return_value = True
    keyword._save_dataset_keyword_table({"kw": {"node-1"}})
    patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt")
    patched_runtime.storage.save.assert_called_once()
    save_args = patched_runtime.storage.save.call_args.args
    assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt"
    assert isinstance(save_args[1], bytes)
def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime):
    """An existing row yields its __data__.table; a missing payload yields {}."""
    existing = _dataset_keyword_table(
        keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}}
    )
    keyword = Jieba(_dataset(existing))
    assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]}
    missing_payload = _dataset_keyword_table(keyword_table_dict=None)
    keyword_with_missing_payload = Jieba(_dataset(missing_payload))
    assert keyword_with_missing_payload._get_dataset_keyword_table() == {}
def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime):
    """A missing keyword-table row is created, persisted, and an empty dict returned."""
    created_tables: list[SimpleNamespace] = []

    def _fake_dataset_keyword_table(**kwargs):
        # Record every constructed row so the test can inspect it.
        kwargs.setdefault("keyword_table", "")
        kwargs.setdefault("keyword_table_dict", None)
        table = SimpleNamespace(**kwargs)
        created_tables.append(table)
        return table

    keyword = Jieba(_dataset(dataset_keyword_table=None))
    monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table)
    monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database")
    result = keyword._get_dataset_keyword_table()
    assert result == {}
    assert len(created_tables) == 1
    assert created_tables[0].dataset_id == "dataset-1"
    assert created_tables[0].data_source_type == "database"
    assert '"index_id":"dataset-1"' in created_tables[0].keyword_table
    patched_runtime.session.add.assert_called_once_with(created_tables[0])
    patched_runtime.session.commit.assert_called_once()
def test_add_and_delete_ids_from_keyword_table_helpers():
    """Adding ids extends keyword sets; deleting prunes ids and empty keywords."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}}
    updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"])
    assert updated["kw1"] == {"node-1", "node-3"}
    assert updated["kw3"] == {"node-3"}
    deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"])
    # Keywords whose node sets become empty are dropped entirely.
    assert "kw3" not in deleted
    assert "kw1" not in deleted
    assert deleted["kw2"] == {"node-2"}
def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch):
    """Node ids matching more query keywords rank first; top-k limits the output."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    handler = MagicMock()
    handler.extract_keywords.return_value = ["kw-a", "kw-b"]
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    ranked_ids = keyword._retrieve_ids_by_query(
        {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}},
        "query",
        k=1,
    )
    # node-2 matches both extracted keywords and wins the single slot.
    assert ranked_ids == ["node-2"]
def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime):
    """An existing segment gets its keywords replaced and committed; missing is a no-op."""

    class _FakeDocumentSegment:
        dataset_id = _Field("dataset_id")
        index_node_id = _Field("index_node_id")

    monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
    monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    segment = SimpleNamespace(keywords=[])
    patched_runtime.session.scalar.return_value = segment
    keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"])
    assert segment.keywords == ["kw1", "kw2"]
    patched_runtime.session.add.assert_called_once_with(segment)
    patched_runtime.session.commit.assert_called_once()
    # When the segment is not found, nothing is added or committed.
    patched_runtime.session.reset_mock()
    patched_runtime.session.scalar.return_value = None
    keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"])
    patched_runtime.session.add.assert_not_called()
    patched_runtime.session.commit.assert_not_called()
def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch):
    """Both segment-keyword entry points update the segment and save the table."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.create_segment_keywords("node-1", ["kw"])
    keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"])
    keyword._save_dataset_keyword_table.assert_called_once()
    keyword._save_dataset_keyword_table.reset_mock()
    keyword.update_segment_keywords_index("node-2", ["kw2"])
    keyword._save_dataset_keyword_table.assert_called_once()
def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch):
    """Provided keywords are used directly; empty lists fall back to extraction."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"auto"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None)
    second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None)
    keyword.multi_create_segment_keywords(
        [
            {"segment": first_segment, "keywords": ["manual"]},
            {"segment": second_segment, "keywords": []},
        ]
    )
    assert first_segment.keywords == ["manual"]
    assert second_segment.keywords == ["auto"]
    saved_table = keyword._save_dataset_keyword_table.call_args.args[0]
    assert saved_table["manual"] == {"node-1"}
    assert saved_table["auto"] == {"node-2"}
def test_set_orjson_default_and_dumps_with_sets():
    """Sets serialize through the orjson default hook; other types raise TypeError."""
    # set_orjson_default returns the set's elements as a serializable sequence.
    assert set(set_orjson_default({"a", "b"})) == {"a", "b"}
    with pytest.raises(TypeError, match="is not JSON serializable"):
        set_orjson_default(("not", "a", "set"))
    payload = {"items": {"a", "b"}}
    json_payload = dumps_with_sets(payload)
    decoded = json.loads(json_payload)
    assert set(decoded["items"]) == {"a", "b"}

View File

@ -0,0 +1,142 @@
import sys
import types
from types import SimpleNamespace
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class _DummyTFIDF:
def __init__(self):
self.stop_words = set()
@staticmethod
def extract_tags(sentence: str, top_k: int | None = 20, **kwargs):
return ["alpha_beta", "during", "gamma"]
def _install_fake_jieba_modules(
    monkeypatch,
    analyse_module: types.ModuleType,
    jieba_attrs: dict[str, object] | None = None,
    tfidf_module: types.ModuleType | None = None,
):
    """Install fake `jieba`/`jieba.analyse` (and optional tfidf) into sys.modules."""
    jieba_module = types.ModuleType("jieba")
    # __path__ marks the fake module as a package so submodule imports resolve.
    jieba_module.__path__ = []
    if jieba_attrs:
        for key, value in jieba_attrs.items():
            setattr(jieba_module, key, value)
    jieba_module.analyse = analyse_module
    analyse_module.__package__ = "jieba"
    monkeypatch.setitem(sys.modules, "jieba", jieba_module)
    monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module)
    if tfidf_module is not None:
        monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module)
    else:
        # Ensure no stale tfidf submodule leaks between tests.
        monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False)
def test_init_uses_existing_default_tfidf(monkeypatch):
    """An existing jieba.analyse.default_tfidf is reused and given STOPWORDS."""
    analyse_module = types.ModuleType("jieba.analyse")
    default_tfidf = _DummyTFIDF()
    analyse_module.default_tfidf = default_tfidf
    _install_fake_jieba_modules(monkeypatch, analyse_module)
    handler = JiebaKeywordTableHandler()
    assert handler._tfidf is default_tfidf
    assert handler._tfidf.stop_words == STOPWORDS
def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch):
    """With no default_tfidf, the analyse.TFIDF class is instantiated and cached."""
    analyse_module = types.ModuleType("jieba.analyse")
    analyse_module.default_tfidf = None

    class _TFIDFFactory(_DummyTFIDF):
        pass

    analyse_module.TFIDF = _TFIDFFactory
    _install_fake_jieba_modules(monkeypatch, analyse_module)
    handler = JiebaKeywordTableHandler()
    assert isinstance(handler._tfidf, _TFIDFFactory)
    # The new instance is cached back onto the module for reuse.
    assert analyse_module.default_tfidf is handler._tfidf
def test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch):
    """When analyse lacks TFIDF, it is imported from jieba.analyse.tfidf instead."""
    analyse_module = types.ModuleType("jieba.analyse")
    analyse_module.default_tfidf = None
    tfidf_module = types.ModuleType("jieba.analyse.tfidf")

    class _ImportedTFIDF(_DummyTFIDF):
        pass

    tfidf_module.TFIDF = _ImportedTFIDF
    _install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module)
    handler = JiebaKeywordTableHandler()
    assert isinstance(handler._tfidf, _ImportedTFIDF)
    assert analyse_module.default_tfidf is handler._tfidf
def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch):
    """Without TFIDF anywhere, a fallback extractor is built."""
    analyse_module = types.ModuleType("jieba.analyse")
    analyse_module.default_tfidf = None
    _install_fake_jieba_modules(monkeypatch, analyse_module)
    handler = JiebaKeywordTableHandler()
    # The fallback returns the most frequent token for topK=1.
    fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1)
    assert fallback_keywords == ["two"]
def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch):
    """The fallback extractor tokenizes via jieba.lcut when that attribute exists."""
    analyse_module = types.ModuleType("jieba.analyse")
    _install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]})
    tfidf = JiebaKeywordTableHandler._build_fallback_tfidf()
    assert tfidf.extract_tags("ignored", topK=1) == ["x"]
def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch):
    """Without jieba.lcut, the fallback extractor tokenizes via jieba.cut."""
    analyse_module = types.ModuleType("jieba.analyse")
    _install_fake_jieba_modules(
        monkeypatch,
        analyse_module,
        jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])},
    )
    tfidf = JiebaKeywordTableHandler._build_fallback_tfidf()
    assert tfidf.extract_tags("ignored", topK=1) == ["foo"]
def test_extract_keywords_expands_subtokens():
    """Hyphenated tags are kept whole and additionally split into sub-tokens."""
    handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler)
    handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"])
    keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3)
    assert "alpha-beta" in keywords
    assert "alpha" in keywords
    assert "beta" in keywords
    assert "during" in keywords
    assert "gamma" in keywords
def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens():
    """Stopword sub-tokens are dropped while the full token survives."""
    handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler)
    expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"})
    assert "alpha-during-beta" in expanded
    assert "alpha" in expanded
    assert "beta" in expanded
    # "during" is a stopword, so it is filtered from the sub-tokens.
    assert "during" not in expanded

View File

@ -0,0 +1,6 @@
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
def test_stopwords_loaded():
    """The STOPWORDS set contains common English stopwords."""
    assert "during" in STOPWORDS
    assert "the" in STOPWORDS

View File

@ -0,0 +1,97 @@
from types import SimpleNamespace
import pytest
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
class _KeywordThatRaises(BaseKeyword):
    """Concrete subclass that delegates every method to the abstract base.

    Used to prove that each BaseKeyword method raises NotImplementedError.
    """

    def create(self, texts: list[Document], **kwargs):
        return super().create(texts, **kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        return super().add_texts(texts, **kwargs)

    def text_exists(self, id: str) -> bool:
        return super().text_exists(id)

    def delete_by_ids(self, ids: list[str]):
        return super().delete_by_ids(ids)

    def delete(self):
        return super().delete()

    def search(self, query: str, **kwargs):
        return super().search(query, **kwargs)
class _KeywordForHelpers(BaseKeyword):
    """No-op subclass used to exercise the BaseKeyword helper methods.

    text_exists() answers from a fixed set of ids supplied at construction.
    """

    def __init__(self, dataset, existing_ids: set[str] | None = None):
        super().__init__(dataset)
        self._existing_ids = existing_ids or set()

    def create(self, texts: list[Document], **kwargs):
        return self

    def add_texts(self, texts: list[Document], **kwargs):
        return None

    def text_exists(self, id: str) -> bool:
        return id in self._existing_ids

    def delete_by_ids(self, ids: list[str]):
        return None

    def delete(self):
        return None

    def search(self, query: str, **kwargs):
        return []
def test_abstract_methods_raise_not_implemented():
    """Every BaseKeyword method raises NotImplementedError when delegated to the base."""
    keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1"))
    with pytest.raises(NotImplementedError):
        keyword.create([])
    with pytest.raises(NotImplementedError):
        keyword.add_texts([])
    with pytest.raises(NotImplementedError):
        keyword.text_exists("doc-1")
    with pytest.raises(NotImplementedError):
        keyword.delete_by_ids(["doc-1"])
    with pytest.raises(NotImplementedError):
        keyword.delete()
    with pytest.raises(NotImplementedError):
        keyword.search("query")
def test_filter_duplicate_texts_removes_existing_doc_ids():
    """Texts whose doc_id already exists are filtered; metadata-less texts are kept."""
    keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"})
    texts = [
        Document(page_content="keep", metadata={"doc_id": "keep"}),
        Document(page_content="duplicate", metadata={"doc_id": "duplicate"}),
        SimpleNamespace(page_content="without-metadata", metadata=None),
    ]
    filtered = keyword._filter_duplicate_texts(texts)
    assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"]
    assert any(text.metadata is None for text in filtered)
def test_get_uuids_returns_only_docs_with_metadata():
    """_get_uuids collects doc_ids, skipping entries without metadata."""
    keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"))
    texts = [
        Document(page_content="doc-1", metadata={"doc_id": "doc-1"}),
        Document(page_content="doc-2", metadata={"doc_id": "doc-2"}),
        SimpleNamespace(page_content="doc-3", metadata=None),
    ]
    assert keyword._get_uuids(texts) == ["doc-1", "doc-2"]

View File

@ -0,0 +1,84 @@
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.keyword.keyword_type import KeyWordType
from core.rag.models.document import Document
def test_get_keyword_factory_returns_jieba_factory(monkeypatch):
    """The JIEBA keyword type resolves to the Jieba class of its module."""
    stub_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba")

    class FakeJieba:
        pass

    stub_module.Jieba = FakeJieba
    monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", stub_module)
    assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba
def test_get_keyword_factory_raises_for_unsupported_type():
    """An unknown keyword-store identifier is rejected with a descriptive error."""
    expected_message = "Keyword store unsupported is not supported"
    with pytest.raises(ValueError, match=expected_message):
        Keyword.get_keyword_factory("unsupported")
def test_keyword_initialization_uses_configured_factory(monkeypatch):
    """Keyword.__init__ builds its processor via the factory chosen by KEYWORD_STORE."""
    fake_processor = MagicMock()
    dataset = SimpleNamespace(id="dataset-1")
    monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA)
    monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor))
    assert Keyword(dataset)._keyword_processor is fake_processor
def test_keyword_methods_forward_to_processor():
    """The Keyword facade passes every public call straight through to its processor."""
    processor = MagicMock()
    processor.text_exists.return_value = True
    processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})]
    facade = Keyword.__new__(Keyword)  # bypass __init__ so no real factory runs
    facade._keyword_processor = processor
    docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})]

    facade.create(docs, foo="bar")
    facade.add_texts(docs, batch=True)
    assert facade.text_exists("doc-1") is True
    facade.delete_by_ids(["doc-1"])
    facade.delete()
    assert facade.search("query", top_k=1) == processor.search.return_value

    processor.create.assert_called_once_with(docs, foo="bar")
    processor.add_texts.assert_called_once_with(docs, batch=True)
    processor.text_exists.assert_called_once_with("doc-1")
    processor.delete_by_ids.assert_called_once_with(["doc-1"])
    processor.delete.assert_called_once()
    processor.search.assert_called_once_with("query", top_k=1)
def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes():
    """__getattr__ forwards callables only; non-callables and a missing processor raise."""

    class Processor:
        value = 1

        @staticmethod
        def custom():
            return "ok"

    facade = Keyword.__new__(Keyword)
    facade._keyword_processor = Processor()
    assert facade.custom() == "ok"
    with pytest.raises(AttributeError):
        _ = facade.value
    facade._keyword_processor = None
    with pytest.raises(AttributeError):
        _ = facade.missing_method

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,74 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
def test_validate_distance_function_accepts_supported_values():
    """Both supported distance functions are returned unchanged."""
    factory = AlibabaCloudMySQLVectorFactory()
    for distance_fn in ("cosine", "euclidean"):
        assert factory._validate_distance_function(distance_fn) == distance_fn
def test_validate_distance_function_rejects_unsupported_values():
    """Anything outside the supported set is rejected with a ValueError."""
    factory = AlibabaCloudMySQLVectorFactory()
    with pytest.raises(ValueError, match="Invalid distance function"):
        factory._validate_distance_function("dot_product")
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch):
    """A dataset that already has an index_struct reuses its class_prefix as collection name."""
    factory = AlibabaCloudMySQLVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}},
        index_struct=None,
    )
    settings = {
        "ALIBABACLOUD_MYSQL_HOST": "host",
        "ALIBABACLOUD_MYSQL_PORT": 3306,
        "ALIBABACLOUD_MYSQL_USER": "user",
        "ALIBABACLOUD_MYSQL_PASSWORD": "password",
        "ALIBABACLOUD_MYSQL_DATABASE": "db",
        "ALIBABACLOUD_MYSQL_MAX_CONNECTION": 5,
        "ALIBABACLOUD_MYSQL_CHARSET": "utf8mb4",
        "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION": "cosine",
        "ALIBABACLOUD_MYSQL_HNSW_M": 6,
    }
    for option, value in settings.items():
        monkeypatch.setattr(alibaba_module.dify_config, option, value)
    with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch):
    """Without an index_struct, the collection name is generated from the dataset id."""
    factory = AlibabaCloudMySQLVectorFactory()
    dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}")
    settings = {
        "ALIBABACLOUD_MYSQL_HOST": "host",
        "ALIBABACLOUD_MYSQL_PORT": 3306,
        "ALIBABACLOUD_MYSQL_USER": "user",
        "ALIBABACLOUD_MYSQL_PASSWORD": "password",
        "ALIBABACLOUD_MYSQL_DATABASE": "db",
        "ALIBABACLOUD_MYSQL_MAX_CONNECTION": 5,
        "ALIBABACLOUD_MYSQL_CHARSET": "utf8mb4",
        "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION": "euclidean",
        "ALIBABACLOUD_MYSQL_HNSW_M": 12,
    }
    for option, value in settings.items():
        monkeypatch.setattr(alibaba_module.dify_config, option, value)
    with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    vector_cls.assert_called_once()
    assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2"
    # init_vector should persist the freshly generated index struct on the dataset
    assert dataset.index_struct is not None

View File

@ -0,0 +1,133 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from core.rag.models.document import Document
def test_init_prefers_openapi_when_api_config_is_provided():
    """When api_config is present, AnalyticdbVector wraps the OpenAPI implementation."""
    openapi_config = AnalyticdbVectorOpenAPIConfig(
        access_key_id="ak",
        access_key_secret="sk",
        region_id="cn-hangzhou",
        instance_id="instance-1",
        account="account",
        account_password="password",
        namespace="dify",
        namespace_password="ns-password",
    )
    with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls:
        vector = AnalyticdbVector("COLLECTION", api_config=openapi_config, sql_config=None)
    assert vector.analyticdb_vector == "openapi_runner"
    openapi_cls.assert_called_once_with("COLLECTION", openapi_config)
def test_init_uses_sql_implementation_when_api_config_is_missing():
    """Without api_config, AnalyticdbVector falls back to the SQL implementation."""
    sql_config = AnalyticdbVectorBySqlConfig(
        host="localhost",
        port=5432,
        account="account",
        account_password="password",
        min_connection=1,
        max_connection=2,
        namespace="dify",
    )
    with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls:
        vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config)
    assert vector.analyticdb_vector == "sql_runner"
    sql_cls.assert_called_once_with("COLLECTION", sql_config)
def test_init_raises_when_both_configs_are_missing():
    """Constructing without any backend configuration is an error."""
    with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"):
        AnalyticdbVector("COLLECTION", api_config=None, sql_config=None)
def test_vector_methods_delegate_to_underlying_implementation():
    """Every public AnalyticdbVector method should delegate to the wrapped runner."""
    runner = MagicMock()
    runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})]
    runner.search_by_full_text.return_value = [Document(page_content="t", metadata={"doc_id": "2"})]
    runner.text_exists.return_value = True
    vector = AnalyticdbVector.__new__(AnalyticdbVector)  # skip __init__; inject the runner directly
    vector.analyticdb_vector = runner
    docs = [Document(page_content="hello", metadata={"doc_id": "d1"})]
    embeddings = [[0.1, 0.2]]

    vector.create(texts=docs, embeddings=embeddings)
    vector.add_texts(documents=docs, embeddings=embeddings)
    assert vector.text_exists("d1") is True
    vector.delete_by_ids(["d1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value
    assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value
    vector.delete()

    # collection dimension is inferred from the embedding length (2)
    runner._create_collection_if_not_exists.assert_called_once_with(2)
    runner.add_texts.assert_any_call(docs, embeddings)
    runner.delete_by_ids.assert_called_once_with(["d1"])
    runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1")
    runner.delete.assert_called_once()
def test_get_type_is_analyticdb():
    """get_type reports the analyticdb backend identifier."""
    instance = AnalyticdbVector.__new__(AnalyticdbVector)
    assert instance.get_type() == "analyticdb"
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
    """With no ANALYTICDB_HOST configured, the factory builds an OpenAPI config."""
    factory = AnalyticdbVectorFactory()
    dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    settings = {
        "ANALYTICDB_HOST": None,
        "ANALYTICDB_KEY_ID": "ak",
        "ANALYTICDB_KEY_SECRET": "sk",
        "ANALYTICDB_REGION_ID": "cn-hz",
        "ANALYTICDB_INSTANCE_ID": "instance",
        "ANALYTICDB_ACCOUNT": "account",
        "ANALYTICDB_PASSWORD": "password",
        "ANALYTICDB_NAMESPACE": "dify",
        "ANALYTICDB_NAMESPACE_PASSWORD": "ns-password",
    }
    for option, value in settings.items():
        monkeypatch.setattr(analyticdb_module.dify_config, option, value)
    with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    positional = vector_cls.call_args.args
    # collection name is lower-cased, api_config populated, sql_config absent
    assert positional[0] == "auto_collection"
    assert isinstance(positional[1], AnalyticdbVectorOpenAPIConfig)
    assert positional[2] is None
    assert dataset.index_struct is not None
def test_factory_builds_sql_config_when_host_is_present(monkeypatch):
    """With ANALYTICDB_HOST configured, the factory builds a SQL config instead."""
    factory = AnalyticdbVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
    )
    settings = {
        "ANALYTICDB_HOST": "127.0.0.1",
        "ANALYTICDB_PORT": 5432,
        "ANALYTICDB_ACCOUNT": "account",
        "ANALYTICDB_PASSWORD": "password",
        "ANALYTICDB_MIN_CONNECTION": 1,
        "ANALYTICDB_MAX_CONNECTION": 3,
        "ANALYTICDB_NAMESPACE": "dify",
    }
    for option, value in settings.items():
        monkeypatch.setattr(analyticdb_module.dify_config, option, value)
    with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    positional = vector_cls.call_args.args
    assert positional[0] == "existing"
    assert positional[1] is None
    assert isinstance(positional[2], AnalyticdbVectorBySqlConfig)

View File

@ -0,0 +1,384 @@
import json
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.models.document import Document
def _request_class(name: str):
class _Request:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
_Request.__name__ = name
return _Request
def _install_openapi_stubs(monkeypatch):
    """Install fake Alibaba Cloud SDK modules into sys.modules for the duration of a test.

    Stubs out ``alibabacloud_gpdb20160503`` (models + client), ``alibabacloud_tea_openapi``
    (Config), and ``Tea.exceptions`` (TeaException) so the module under test can import
    them without the real SDK installed. Returns a SimpleNamespace exposing the pieces
    tests need to interact with: ``models``, ``TeaException`` and ``OpenApiConfig``.
    """
    # Fake top-level gpdb package; __path__ marks it as a package so submodule imports work.
    gpdb_package = types.ModuleType("alibabacloud_gpdb20160503")
    gpdb_package.__path__ = []
    gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models")
    # Every request type used by the code under test gets a kwargs-to-attributes stand-in.
    for class_name in [
        "InitVectorDatabaseRequest",
        "DescribeNamespaceRequest",
        "CreateNamespaceRequest",
        "DescribeCollectionRequest",
        "CreateCollectionRequest",
        "UpsertCollectionDataRequestRows",
        "UpsertCollectionDataRequest",
        "QueryCollectionDataRequest",
        "DeleteCollectionDataRequest",
        "DeleteCollectionRequest",
    ]:
        setattr(gpdb_models, class_name, _request_class(class_name))

    class _Client:
        # Mimics the SDK client: remembers the config it was constructed with.
        def __init__(self, config):
            self.config = config

    gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client")
    gpdb_client.Client = _Client
    gpdb_package.models = gpdb_models
    # Fake tea_openapi package providing the Config class used to build clients.
    tea_openapi = types.ModuleType("alibabacloud_tea_openapi")
    tea_openapi.__path__ = []
    tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models")

    class OpenApiConfig:
        # Accepts arbitrary keyword arguments and exposes them as attributes.
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    tea_openapi_models.Config = OpenApiConfig
    tea_openapi.models = tea_openapi_models
    # Fake Tea package with a TeaException that carries a status code under both
    # the camelCase and snake_case attribute names the production code may read.
    tea_package = types.ModuleType("Tea")
    tea_package.__path__ = []
    tea_exceptions = types.ModuleType("Tea.exceptions")

    class TeaError(Exception):
        def __init__(self, status_code=None, **kwargs):
            super().__init__("TeaException")
            # Allow construction via TeaError(statusCode=404) as the SDK does.
            status_code = kwargs.get("statusCode", status_code)
            self.statusCode = status_code
            self.status_code = status_code

    tea_exceptions.TeaException = TeaError
    tea_package.exceptions = tea_exceptions
    # Register all fakes; monkeypatch restores sys.modules after the test.
    monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package)
    monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models)
    monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client)
    monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi)
    monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models)
    monkeypatch.setitem(sys.modules, "Tea", tea_package)
    monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions)
    return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig)
def _config() -> AnalyticdbVectorOpenAPIConfig:
    """Return a fully populated OpenAPI config fixture."""
    values = {
        "access_key_id": "ak",
        "access_key_secret": "sk",
        "region_id": "cn-hangzhou",
        "instance_id": "instance-1",
        "account": "account",
        "account_password": "password",
        "namespace": "dify",
        "namespace_password": "ns-password",
    }
    return AnalyticdbVectorOpenAPIConfig(**values)
@pytest.mark.parametrize(
    ("field", "value", "error_message"),
    [
        ("access_key_id", "", "ANALYTICDB_KEY_ID"),
        ("access_key_secret", "", "ANALYTICDB_KEY_SECRET"),
        ("region_id", "", "ANALYTICDB_REGION_ID"),
        ("instance_id", "", "ANALYTICDB_INSTANCE_ID"),
        ("account", "", "ANALYTICDB_ACCOUNT"),
        ("account_password", "", "ANALYTICDB_PASSWORD"),
        ("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"),
    ],
)
def test_openapi_config_validation(field, value, error_message):
    """Blanking any required field makes validation fail with a field-specific message."""
    payload = _config().model_dump()
    payload[field] = value
    with pytest.raises(ValueError, match=error_message):
        AnalyticdbVectorOpenAPIConfig.model_validate(payload)
def test_openapi_config_to_client_params():
    """to_analyticdb_client_params exposes the credential fields and a read timeout."""
    params = _config().to_analyticdb_client_params()
    expected = {
        "access_key_id": "ak",
        "access_key_secret": "sk",
        "region_id": "cn-hangzhou",
        "read_timeout": 60000,
    }
    for key, value in expected.items():
        assert params[key] == value
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
    """__init__ lower-cases the collection name, builds the client, and runs _initialize."""
    stubs = _install_openapi_stubs(monkeypatch)
    init_spy = MagicMock()
    monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", init_spy)
    vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config())
    assert vector._collection_name == "collection_1"
    assert isinstance(vector._client_config, stubs.OpenApiConfig)
    assert vector._client_config.user_agent == "dify"
    assert vector._client_config.access_key_id == "ak"
    assert vector._client.config is vector._client_config
    init_spy.assert_called_once_with()
def test_initialize_skips_when_cached(monkeypatch):
    """A redis cache hit short-circuits database and namespace initialization."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    for attr, replacement in (
        ("lock", MagicMock(return_value=fake_lock)),
        ("get", MagicMock(return_value=1)),
        ("set", MagicMock()),
    ):
        monkeypatch.setattr(openapi_module.redis_client, attr, replacement)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._initialize_vector_database = MagicMock()
    vector._create_namespace_if_not_exists = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_not_called()
    vector._create_namespace_if_not_exists.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
    """A redis cache miss triggers full initialization and records the cache flag."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    for attr, replacement in (
        ("lock", MagicMock(return_value=fake_lock)),
        ("get", MagicMock(return_value=None)),
        ("set", MagicMock()),
    ):
        monkeypatch.setattr(openapi_module.redis_client, attr, replacement)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._initialize_vector_database = MagicMock()
    vector._create_namespace_if_not_exists = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_called_once()
    vector._create_namespace_if_not_exists.assert_called_once()
    openapi_module.redis_client.set.assert_called_once()
def test_initialize_vector_database_calls_openapi_client(monkeypatch):
    """_initialize_vector_database issues an init request carrying the config values."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._initialize_vector_database()
    sent_request = vector._client.init_vector_database.call_args.args[0]
    assert sent_request.dbinstance_id == "instance-1"
    assert sent_request.region_id == "cn-hangzhou"
    assert sent_request.manager_account == "account"
    assert sent_request.manager_account_password == "password"
def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
    """A 404 from describe_namespace means the namespace is missing, so it is created."""
    stubs = _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404)
    vector._create_namespace_if_not_exists()
    vector._client.create_namespace.assert_called_once()
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
    """Non-404 API failures are surfaced as a ValueError rather than swallowed."""
    stubs = _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500)
    with pytest.raises(ValueError, match="failed to create namespace"):
        vector._create_namespace_if_not_exists()
def test_create_namespace_noop_when_namespace_exists(monkeypatch):
    """When describe_namespace succeeds, no create call is made."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._create_namespace_if_not_exists()
    vector._client.describe_namespace.assert_called_once()
    vector._client.create_namespace.assert_not_called()
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
    """A cache miss plus a 404 from describe_collection leads to collection creation."""
    stubs = _install_openapi_stubs(monkeypatch)
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    for attr, replacement in (
        ("lock", MagicMock(return_value=fake_lock)),
        ("get", MagicMock(return_value=None)),
        ("set", MagicMock()),
    ):
        monkeypatch.setattr(openapi_module.redis_client, attr, replacement)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404)
    vector._create_collection_if_not_exists(embedding_dimension=1024)
    vector._client.create_collection.assert_called_once()
    openapi_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
    """A redis cache hit skips both the describe and create collection calls."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    for attr, replacement in (
        ("lock", MagicMock(return_value=fake_lock)),
        ("get", MagicMock(return_value=1)),
        ("set", MagicMock()),
    ):
        monkeypatch.setattr(openapi_module.redis_client, attr, replacement)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._create_collection_if_not_exists(embedding_dimension=1024)
    vector._client.describe_collection.assert_not_called()
    vector._client.create_collection.assert_not_called()
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
    """Unexpected describe_collection failures are wrapped in a ValueError."""
    stubs = _install_openapi_stubs(monkeypatch)
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    for attr, replacement in (
        ("lock", MagicMock(return_value=fake_lock)),
        ("get", MagicMock(return_value=None)),
        ("set", MagicMock()),
    ):
        monkeypatch.setattr(openapi_module.redis_client, attr, replacement)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500)
    with pytest.raises(ValueError, match="failed to create collection collection_1"):
        vector._create_collection_if_not_exists(embedding_dimension=512)
def test_openapi_add_delete_and_search_methods(monkeypatch):
    """End-to-end exercise of add_texts, text_exists, deletes, and both search paths.

    Uses a MagicMock client so we can inspect the request objects the implementation
    builds and stub the responses it consumes.
    """
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    documents = [
        Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}),
        # metadata-less doc must be skipped by add_texts
        SimpleNamespace(page_content="doc 2", metadata=None),
    ]
    embeddings = [[0.1, 0.2], [0.2, 0.3]]
    vector.add_texts(documents, embeddings)
    upsert_request = vector._client.upsert_collection_data.call_args.args[0]
    assert upsert_request.collection == "collection_1"
    # only the document with metadata becomes a row
    assert len(upsert_request.rows) == 1
    # one match in the query response -> the id exists
    vector._client.query_collection_data.return_value = SimpleNamespace(
        body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()]))
    )
    assert vector.text_exists("d1") is True
    vector.delete_by_ids(["d1", "d2"])
    request = vector._client.delete_collection_data.call_args.args[0]
    assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')"
    vector.delete_by_metadata_field("document_id", "doc-1")
    request = vector._client.delete_collection_data.call_args.args[0]
    assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'"
    # two matches straddling the 0.5 / 0.2 score thresholds used below
    match_high = SimpleNamespace(
        score=0.9,
        metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"},
        values=SimpleNamespace(value=[1.0, 2.0]),
    )
    match_low = SimpleNamespace(
        score=0.1,
        metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"},
        values=SimpleNamespace(value=[3.0, 4.0]),
    )
    vector._client.query_collection_data.return_value = SimpleNamespace(
        body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high]))
    )
    docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    assert len(docs_by_vector) == 1
    assert docs_by_vector[0].page_content == "high"
    assert docs_by_vector[0].metadata["score"] == 0.9
    docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2)
    assert len(docs_by_text) == 1
    assert docs_by_text[0].page_content == "high"
def test_text_exists_returns_false_when_matches_empty(monkeypatch):
    """An empty match list from the query endpoint means the id does not exist."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    empty_response = SimpleNamespace(body=SimpleNamespace(matches=SimpleNamespace(match=[])))
    vector._client.query_collection_data.return_value = empty_response
    assert vector.text_exists("missing-id") is False
def test_openapi_delete_success(monkeypatch):
    """delete issues exactly one delete_collection call."""
    _install_openapi_stubs(monkeypatch)
    instance = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    instance._collection_name = "collection_1"
    instance.config = _config()
    instance._client = MagicMock()
    instance.delete()
    instance._client.delete_collection.assert_called_once()
def test_openapi_delete_propagates_errors(monkeypatch):
    """delete does not swallow client failures."""
    _install_openapi_stubs(monkeypatch)
    instance = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    instance._collection_name = "collection_1"
    instance.config = _config()
    instance._client = MagicMock()
    instance._client.delete_collection.side_effect = RuntimeError("boom")
    with pytest.raises(RuntimeError, match="boom"):
        instance.delete()

View File

@ -0,0 +1,427 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
import psycopg2.errors
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
AnalyticdbVectorBySql,
AnalyticdbVectorBySqlConfig,
)
from core.rag.models.document import Document
def _config_values() -> dict:
return {
"host": "localhost",
"port": 5432,
"account": "account",
"account_password": "password",
"min_connection": 1,
"max_connection": 2,
"namespace": "dify",
}
@pytest.mark.parametrize(
    ("field", "value", "error_message"),
    [
        ("host", "", "ANALYTICDB_HOST"),
        ("port", 0, "ANALYTICDB_PORT"),
        ("account", "", "ANALYTICDB_ACCOUNT"),
        ("account_password", "", "ANALYTICDB_PASSWORD"),
        ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"),
        ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"),
    ],
)
def test_sql_config_required_fields(field, value, error_message):
    """Each required field must be non-empty/non-zero, with a field-specific error."""
    payload = _config_values()
    payload[field] = value
    with pytest.raises(ValueError, match=error_message):
        AnalyticdbVectorBySqlConfig.model_validate(payload)
def test_sql_config_rejects_min_connection_greater_than_max_connection():
    """min_connection must not exceed max_connection."""
    payload = _config_values()
    payload.update(min_connection=10, max_connection=2)
    with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"):
        AnalyticdbVectorBySqlConfig.model_validate(payload)
def test_initialize_skips_when_cache_exists(monkeypatch):
    """A redis cache hit short-circuits SQL database initialization."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    for attr, replacement in (
        ("lock", MagicMock(return_value=fake_lock)),
        ("get", MagicMock(return_value=1)),
        ("set", MagicMock()),
    ):
        monkeypatch.setattr(sql_module.redis_client, attr, replacement)
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._initialize_vector_database = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
    """A redis cache miss runs database initialization and records the cache flag."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    for attr, replacement in (
        ("lock", MagicMock(return_value=fake_lock)),
        ("get", MagicMock(return_value=None)),
        ("set", MagicMock()),
    ):
        monkeypatch.setattr(sql_module.redis_client, attr, replacement)
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._initialize_vector_database = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_called_once()
    sql_module.redis_client.set.assert_called_once()
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch):
    """_create_connection_pool delegates to psycopg2's SimpleConnectionPool."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"
    fake_pool = MagicMock()
    monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=fake_pool))
    assert vector._create_connection_pool() is fake_pool
    sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once()
def test_get_cursor_context_manager_handles_connection_lifecycle():
    """_get_cursor yields a cursor, then closes/commits/returns the connection."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    fake_cursor = MagicMock()
    fake_connection = MagicMock()
    fake_connection.cursor.return_value = fake_cursor
    fake_pool = MagicMock()
    fake_pool.getconn.return_value = fake_connection
    vector.pool = fake_pool
    with vector._get_cursor() as cur:
        assert cur is fake_cursor
    fake_cursor.close.assert_called_once()
    fake_connection.commit.assert_called_once()
    fake_pool.putconn.assert_called_once_with(fake_connection)
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch):
    """add_texts batches only the documents that carry metadata."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    fake_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    vector._get_cursor = _fake_cursor
    monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id")
    monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock())
    docs = [
        Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}),
        SimpleNamespace(page_content="doc 2", metadata=None),
    ]
    vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]])
    batch_args = sql_module.psycopg2.extras.execute_batch.call_args.args
    assert batch_args[0] is fake_cursor
    # only the metadata-carrying document produced a row
    assert len(batch_args[2]) == 1
def test_text_exists_returns_true_and_false_based_on_query_result():
    """text_exists maps a fetched row to True and an empty result to False."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    fake_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    vector._get_cursor = _fake_cursor
    fake_cursor.fetchone.return_value = ("row",)
    assert vector.text_exists("d1") is True
    fake_cursor.fetchone.return_value = None
    assert vector.text_exists("d1") is False
def test_delete_by_ids_handles_empty_input_and_missing_table_error():
    """Empty id lists are a no-op, and a missing table does not raise."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    fake_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    vector._get_cursor = _fake_cursor
    vector.delete_by_ids([])
    fake_cursor.execute.assert_not_called()
    fake_cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist")
    vector.delete_by_ids(["d1"])  # must not raise
def test_delete_by_metadata_field_handles_missing_table_error():
    """A missing table during metadata deletes is tolerated silently."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    fake_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    vector._get_cursor = _fake_cursor
    fake_cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist")
    vector.delete_by_metadata_field("document_id", "doc-1")  # must not raise
@pytest.mark.parametrize("invalid_top_k", [0, "x", -1])
def test_search_by_vector_validates_top_k(invalid_top_k):
    """search_by_vector rejects non-positive or non-integer top_k values."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k)
def test_search_by_vector_returns_documents_above_threshold():
    """Only rows whose score clears the threshold become Documents with a score."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    fake_cursor = MagicMock()
    rows = [
        ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}),
        ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}),
    ]
    fake_cursor.__iter__.return_value = iter(rows)

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    vector._get_cursor = _fake_cursor
    results = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    assert len(results) == 1
    assert results[0].page_content == "content 1"
    assert results[0].metadata["score"] == 0.8
@pytest.mark.parametrize("invalid_top_k", [0, "x", -1])
def test_search_by_full_text_validates_top_k(invalid_top_k):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.table_name = "dify.collection"
with pytest.raises(ValueError, match="top_k must be a positive integer"):
vector.search_by_full_text("query", top_k=invalid_top_k)
def test_search_by_full_text_returns_documents():
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.table_name = "dify.collection"
cursor = MagicMock()
cursor.__iter__.return_value = iter(
[
("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9),
]
)
@contextmanager
def _cursor_context():
yield cursor
vector._get_cursor = _cursor_context
docs = vector.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"])
assert len(docs) == 1
assert docs[0].metadata["score"] == 0.9
assert docs[0].page_content == "content 1"
def test_delete_drops_table():
    """delete issues exactly one statement (the table drop)."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()
    @contextmanager
    def _cursor_context():
        yield cursor
    vector._get_cursor = _cursor_context
    vector.delete()
    cursor.execute.assert_called_once()
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch):
    """__init__ lowercases the collection name and builds a pool if none exists."""
    config = AnalyticdbVectorBySqlConfig(**_config_values())
    created_pool = MagicMock()
    # Stub out the heavy initialization and pool creation paths.
    monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock())
    monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=created_pool))
    vector = AnalyticdbVectorBySql("My_Collection", config)
    assert vector._collection_name == "my_collection"
    assert vector.table_name == "dify.my_collection"
    assert vector.databaseName == "knowledgebase"
    assert vector.pool is created_pool
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch):
    """Initialization survives 'database already exists' and an existing zh_cn config."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"
    # Bootstrap connection: CREATE DATABASE fails because it already exists.
    bootstrap_cursor = MagicMock()
    bootstrap_connection = MagicMock()
    bootstrap_connection.cursor.return_value = bootstrap_cursor
    bootstrap_cursor.execute.side_effect = RuntimeError("database already exists")
    monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection))
    # Worker connection obtained from the pool performs the remaining DDL.
    worker_cursor = MagicMock()
    worker_connection = MagicMock()
    worker_cursor.connection = worker_connection
    def _execute(sql, *args, **kwargs):
        # Simulate the text-search configuration already being present.
        if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql:
            raise RuntimeError("already exists")
    worker_cursor.execute.side_effect = _execute
    pooled_connection = MagicMock()
    pooled_connection.cursor.return_value = worker_cursor
    pool = MagicMock()
    pool.getconn.return_value = pooled_connection
    vector._create_connection_pool = MagicMock(return_value=pool)
    vector._initialize_vector_database()
    bootstrap_cursor.close.assert_called_once()
    bootstrap_connection.close.assert_called_once()
    vector._create_connection_pool.assert_called_once()
    # The helper function and the dify schema must still be created.
    assert any(
        "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0]
        for call in worker_cursor.execute.call_args_list
    )
    assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list)
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch):
    """A zhparser extension failure rolls back and surfaces a RuntimeError."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"
    bootstrap_cursor = MagicMock()
    bootstrap_connection = MagicMock()
    bootstrap_connection.cursor.return_value = bootstrap_cursor
    monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection))
    worker_cursor = MagicMock()
    worker_connection = MagicMock()
    worker_cursor.connection = worker_connection
    # Every worker statement fails, triggering the zhparser error path.
    worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable")
    pooled_connection = MagicMock()
    pooled_connection.cursor.return_value = worker_cursor
    pool = MagicMock()
    pool.getconn.return_value = pooled_connection
    vector._create_connection_pool = MagicMock(return_value=pool)
    with pytest.raises(RuntimeError, match="Failed to create zhparser extension"):
        vector._initialize_vector_database()
    worker_connection.rollback.assert_called_once()
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch):
    """Table and index DDL run under the redis lock and the cache flag is set."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._collection_name = "collection"
    vector.table_name = "dify.collection"
    # Redis lock is a context manager; cache miss forces the creation path.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())
    cursor = MagicMock()
    @contextmanager
    def _cursor_context():
        yield cursor
    vector._get_cursor = _cursor_context
    vector._create_collection_if_not_exists(embedding_dimension=3)
    assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list)
    assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list)
    sql_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch):
    """Errors other than 'already exists' propagate out of the creation path."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._collection_name = "collection"
    vector.table_name = "dify.collection"
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())
    cursor = MagicMock()
    cursor.execute.side_effect = RuntimeError("permission denied")
    @contextmanager
    def _cursor_context():
        yield cursor
    vector._get_cursor = _cursor_context
    with pytest.raises(RuntimeError, match="permission denied"):
        vector._create_collection_if_not_exists(embedding_dimension=3)
def test_delete_methods_raise_when_error_is_not_missing_table():
    """Non-UndefinedTable failures in delete paths are re-raised."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()
    @contextmanager
    def _cursor_context():
        yield cursor
    vector._get_cursor = _cursor_context
    cursor.execute.side_effect = RuntimeError("unexpected delete failure")
    with pytest.raises(RuntimeError, match="unexpected delete failure"):
        vector.delete_by_ids(["doc-1"])
    cursor.execute.side_effect = RuntimeError("unexpected metadata failure")
    with pytest.raises(RuntimeError, match="unexpected metadata failure"):
        vector.delete_by_metadata_field("document_id", "doc-1")

View File

@ -0,0 +1,542 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_pymochow_modules():
    """Create a family of stand-in ``pymochow`` modules keyed by dotted name.

    The returned mapping is inserted into ``sys.modules`` so that importing
    ``core.rag.datasource.vdb.baidu.baidu_vector`` resolves every pymochow
    symbol without the real SDK being installed.
    """
    pymochow = types.ModuleType("pymochow")
    # __path__ marks the module as a package so submodule imports resolve.
    pymochow.__path__ = []
    pymochow_auth = types.ModuleType("pymochow.auth")
    pymochow_auth.__path__ = []
    pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials")
    pymochow_configuration = types.ModuleType("pymochow.configuration")
    pymochow_exception = types.ModuleType("pymochow.exception")
    pymochow_model = types.ModuleType("pymochow.model")
    pymochow_model.__path__ = []
    pymochow_model_database = types.ModuleType("pymochow.model.database")
    pymochow_model_enum = types.ModuleType("pymochow.model.enum")
    pymochow_model_schema = types.ModuleType("pymochow.model.schema")
    pymochow_model_table = types.ModuleType("pymochow.model.table")
    class _SimpleObject:
        # Generic stand-in: records positional args, exposes kwargs as attributes.
        def __init__(self, *args, **kwargs):
            self.args = args
            for key, value in kwargs.items():
                setattr(self, key, value)
    class ServerError(Exception):
        # Mirrors pymochow's ServerError carrying a numeric error code.
        def __init__(self, code):
            super().__init__(f"server error {code}")
            self.code = code
    class ServerErrCode:
        TABLE_NOT_EXIST = 1001
        DB_ALREADY_EXIST = 1002
    class IndexType:
        # __members__ emulates the enum lookup interface used by the vector store.
        __members__ = {"HNSW": "HNSW"}
    class MetricType:
        __members__ = {"IP": "IP"}
    class IndexState:
        NORMAL = "NORMAL"
    class TableState:
        NORMAL = "NORMAL"
    class InvertedIndexAnalyzer:
        DEFAULT_ANALYZER = "DEFAULT_ANALYZER"
    class InvertedIndexParseMode:
        COARSE_MODE = "COARSE_MODE"
    class InvertedIndexFieldAttribute:
        ANALYZED = "ANALYZED"
    class FieldType:
        STRING = "STRING"
        TEXT = "TEXT"
        JSON = "JSON"
        FLOAT_VECTOR = "FLOAT_VECTOR"
    pymochow.MochowClient = _SimpleObject
    pymochow_credentials.BceCredentials = _SimpleObject
    pymochow_configuration.Configuration = _SimpleObject
    pymochow_exception.ServerError = ServerError
    pymochow_model_database.Database = _SimpleObject
    pymochow_model_enum.FieldType = FieldType
    pymochow_model_enum.IndexState = IndexState
    pymochow_model_enum.IndexType = IndexType
    pymochow_model_enum.MetricType = MetricType
    pymochow_model_enum.ServerErrCode = ServerErrCode
    pymochow_model_enum.TableState = TableState
    # All schema classes share the same _SimpleObject stand-in.
    for cls_name in [
        "AutoBuildRowCountIncrement",
        "Field",
        "FilteringIndex",
        "HNSWParams",
        "InvertedIndex",
        "InvertedIndexParams",
        "Schema",
        "VectorIndex",
    ]:
        setattr(pymochow_model_schema, cls_name, _SimpleObject)
    pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer
    pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute
    pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode
    for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]:
        setattr(pymochow_model_table, cls_name, _SimpleObject)
    # Wire up the package hierarchy so attribute-style access works too.
    pymochow.auth = pymochow_auth
    pymochow.model = pymochow_model
    pymochow_auth.bce_credentials = pymochow_credentials
    pymochow_model.database = pymochow_model_database
    pymochow_model.enum = pymochow_model_enum
    pymochow_model.schema = pymochow_model_schema
    pymochow_model.table = pymochow_model_table
    modules = {
        "pymochow": pymochow,
        "pymochow.auth": pymochow_auth,
        "pymochow.auth.bce_credentials": pymochow_credentials,
        "pymochow.configuration": pymochow_configuration,
        "pymochow.exception": pymochow_exception,
        "pymochow.model": pymochow_model,
        "pymochow.model.database": pymochow_model_database,
        "pymochow.model.enum": pymochow_model_enum,
        "pymochow.model.schema": pymochow_model_schema,
        "pymochow.model.table": pymochow_model_table,
    }
    return modules
@pytest.fixture
def baidu_module(monkeypatch):
    """Import the baidu_vector module with every pymochow module swapped for a fake."""
    fake_modules = _build_fake_pymochow_modules()
    for module_name, fake_module in fake_modules.items():
        monkeypatch.setitem(sys.modules, module_name, fake_module)
    import core.rag.datasource.vdb.baidu.baidu_vector as module

    # Reload so the module binds against the fakes even if previously imported.
    return importlib.reload(module)
def test_baidu_config_validation(baidu_module):
    """Each required BaiduConfig field reports its env-var name when blank."""
    values = {
        "endpoint": "https://example.com",
        "account": "account",
        "api_key": "key",
        "database": "database",
    }
    config = baidu_module.BaiduConfig.model_validate(values)
    assert config.endpoint == "https://example.com"
    for key, error_message in [
        ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"),
        ("account", "BAIDU_VECTOR_DB_ACCOUNT"),
        ("api_key", "BAIDU_VECTOR_DB_API_KEY"),
        ("database", "BAIDU_VECTOR_DB_DATABASE"),
    ]:
        invalid = dict(values)
        invalid[key] = ""
        with pytest.raises(ValueError, match=error_message):
            baidu_module.BaiduConfig.model_validate(invalid)
def test_get_search_result_handles_metadata_and_threshold(baidu_module):
    """_get_search_res parses str/dict/other metadata and drops low scores."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    response = SimpleNamespace(
        rows=[
            {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9},
            {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4},
            {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95},
        ]
    )
    docs = vector._get_search_res(response, score_threshold=0.8)
    # doc2 (score 0.4) is filtered out by the threshold.
    assert len(docs) == 2
    assert docs[0].page_content == "doc1"
    assert docs[0].metadata["score"] == 0.9
    assert docs[1].page_content == "doc3"
def test_delete_by_ids_and_delete_by_metadata_field(baidu_module):
    """Empty id list is a no-op; metadata filter escapes embedded quotes."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    vector._collection_name = "collection_1"
    vector.delete_by_ids([])
    table.delete.assert_not_called()
    vector.delete_by_ids(["id1", "id2"])
    table.delete.assert_called_once()
    table.delete.reset_mock()
    vector.delete_by_metadata_field("source", 'abc"def')
    delete_filter = table.delete.call_args.kwargs["filter"]
    assert delete_filter == 'metadata["source"] = "abc\\"def"'
def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module):
    """delete ignores TABLE_NOT_EXIST but re-raises any other server code."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    vector._db.drop_table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST)
    vector.delete()
    vector._db.drop_table.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector.delete()
def test_init_database_uses_existing_or_creates_when_missing(baidu_module):
    """_init_database reuses an existing DB, creates one, or tolerates a race."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._client = MagicMock()
    vector._client_config = SimpleNamespace(database="my_db")
    # Database already listed -> reuse it.
    vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")]
    vector._client.database.return_value = "existing_db"
    assert vector._init_database() == "existing_db"
    # Database missing -> create then open it.
    vector._client.list_databases.return_value = []
    vector._client.database.return_value = "created_db"
    vector._client.create_database.side_effect = None
    assert vector._init_database() == "created_db"
    # DB_ALREADY_EXIST during create (concurrent creator) is tolerated.
    vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST)
    assert vector._init_database() == "created_db"
def test_table_existed_checks_table_access(baidu_module):
    """_table_existed maps TABLE_NOT_EXIST to False and re-raises other codes."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    vector._db.table.return_value = MagicMock()
    assert vector._table_existed() is True
    vector._db.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST)
    assert vector._table_existed() is False
    vector._db.table.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector._table_existed()
def test_search_methods_delegate_to_database_table(baidu_module):
    """Vector and BM25 searches both funnel results through _get_search_res."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    vector._get_search_res = MagicMock(return_value=[Document(page_content="doc", metadata={"doc_id": "1"})])
    table = MagicMock()
    vector._db.table.return_value = table
    table.search.return_value = "vector_result"
    table.bm25_search.return_value = "bm25_result"
    result1 = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2)
    result2 = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2)
    assert result1 == vector._get_search_res.return_value
    assert result2 == vector._get_search_res.return_value
    assert vector._get_search_res.call_count == 2
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch):
    """Factory lowercases the generated collection name and persists index_struct."""
    factory = baidu_module.BaiduVectorFactory()
    dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide every config value the factory reads from dify_config.
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300)
    with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection"
    assert dataset.index_struct is not None
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch):
    """__init__ wires client/db; create() builds the table then inserts texts."""
    init_client = MagicMock(return_value="client")
    init_database = MagicMock(return_value="database")
    monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client)
    monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database)
    config = baidu_module.BaiduConfig(
        endpoint="https://example.com",
        account="account",
        api_key="key",
        database="db",
    )
    vector = baidu_module.BaiduVector(collection_name="my_collection", config=config)
    assert vector.get_type() == baidu_module.VectorType.BAIDU
    assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection"
    assert vector._client == "client"
    assert vector._db == "database"
    vector._create_table = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="p1", metadata={"doc_id": "d1"})]
    vector.create(docs, [[0.1, 0.2]])
    # Table dimension comes from the embedding length.
    vector._create_table.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_batches_rows(baidu_module):
    """A small document set is upserted in a single batch."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    docs = [
        Document(page_content="doc-1", metadata={"doc_id": "id-1", "document_id": "doc-1"}),
        Document(page_content="doc-2", metadata={"doc_id": "id-2", "document_id": "doc-2"}),
    ]
    vector.add_texts(docs, [[0.1, 0.2], [0.3, 0.4]])
    assert table.upsert.call_count == 1
    inserted_rows = table.upsert.call_args.kwargs["rows"]
    assert len(inserted_rows) == 2
def test_add_texts_batches_more_than_batch_size(baidu_module):
    """1001 documents are split into a full batch of 1000 plus a remainder of 1."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    docs = [
        Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"})
        for idx in range(1001)
    ]
    embeddings = [[0.1, 0.2] for _ in range(1001)]
    vector.add_texts(docs, embeddings)
    assert table.upsert.call_count == 2
    assert len(table.upsert.call_args_list[0].kwargs["rows"]) == 1000
    assert len(table.upsert.call_args_list[1].kwargs["rows"]) == 1
def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module):
    """text_exists is True only for a response object with code == 0."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    table.query.return_value = SimpleNamespace(code=0)
    assert vector.text_exists("id-1") is True
    table.query.return_value = SimpleNamespace(code=1)
    assert vector.text_exists("id-1") is False
    table.query.return_value = None
    assert vector.text_exists("id-1") is False
def test_get_search_result_handles_invalid_metadata_json(baidu_module):
    """Unparseable metadata JSON still yields a Document with just the score."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    response = SimpleNamespace(rows=[{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}])
    docs = vector._get_search_res(response, score_threshold=0.1)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.7
    assert "document_id" not in docs[0].metadata
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch):
    """_init_client chains BceCredentials -> Configuration -> MochowClient."""
    credentials = MagicMock(return_value="credentials")
    configuration = MagicMock(return_value="configuration")
    client_cls = MagicMock(return_value="client")
    monkeypatch.setattr(baidu_module, "BceCredentials", credentials)
    monkeypatch.setattr(baidu_module, "Configuration", configuration)
    monkeypatch.setattr(baidu_module, "MochowClient", client_cls)
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint")
    client = vector._init_client(config)
    assert client == "client"
    credentials.assert_called_once_with("account", "key")
    configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint")
    client_cls.assert_called_once_with("configuration")
def test_init_database_raises_for_unknown_create_database_error(baidu_module):
    """Server errors other than DB_ALREADY_EXIST propagate from _init_database."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._client = MagicMock()
    vector._client_config = SimpleNamespace(database="my_db")
    vector._client.list_databases.return_value = []
    vector._client.create_database.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector._init_database()
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch):
    """_create_table short-circuits on cache/existing table and creates otherwise."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._client_config = SimpleNamespace(
        index_type="HNSW",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    vector._db = MagicMock()
    table = MagicMock()
    table.state = baidu_module.TableState.NORMAL
    vector._db.describe_table.return_value = table
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None)
    monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock())
    # Cached table skips all work.
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_table(3)
    vector._db.create_table.assert_not_called()
    # Existing table also skips creation.
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    vector._table_existed.return_value = True
    vector._create_table(3)
    vector._db.create_table.assert_not_called()
    # Create table when cache is empty and table does not exist.
    vector._table_existed.return_value = False
    vector._create_table(3)
    vector._db.create_table.assert_called_once()
    baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600)
    table.rebuild_index.assert_called_once_with(vector.vector_index)
    vector._wait_for_index_ready.assert_called_once_with(table, 3600)
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch):
    """Unsupported index_type or metric_type values raise ValueError."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()
    vector._client_config = SimpleNamespace(
        index_type="INVALID",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    with pytest.raises(ValueError, match="unsupported index_type"):
        vector._create_table(3)
    vector._client_config.index_type = "HNSW"
    vector._client_config.metric_type = "INVALID"
    with pytest.raises(ValueError, match="unsupported metric_type"):
        vector._create_table(3)
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch):
    """A table stuck in CREATING past the deadline raises TimeoutError."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._client_config = SimpleNamespace(
        index_type="HNSW",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    vector._db = MagicMock()
    vector._db.describe_table.return_value = SimpleNamespace(state="CREATING")
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None)
    # time.time() returns 0 on entry, then 301 (past the 300s deadline).
    monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301]))
    with pytest.raises(TimeoutError, match="Table creation timeout"):
        vector._create_table(3)
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch):
    """A dataset with an index_struct keeps its stored (lowercased) class_prefix."""
    factory = baidu_module.BaiduVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300)
    with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
View File

@ -0,0 +1,199 @@
import importlib
import sys
import types
from collections import UserDict
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_chroma_modules():
chroma = types.ModuleType("chromadb")
chroma.DEFAULT_TENANT = "default_tenant"
chroma.DEFAULT_DATABASE = "default_database"
class Settings:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
class QueryResult(UserDict):
pass
class _Collection:
def __init__(self):
self.upsert = MagicMock()
self.delete = MagicMock()
self.query = MagicMock()
self.get = MagicMock(return_value={})
class _Client:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.collection = _Collection()
self.get_or_create_collection = MagicMock(return_value=self.collection)
self.delete_collection = MagicMock()
chroma.Settings = Settings
chroma.QueryResult = QueryResult
chroma.HttpClient = _Client
return chroma
@pytest.fixture
def chroma_module(monkeypatch):
    """Import the chroma_vector module with the real chromadb swapped for a fake."""
    monkeypatch.setitem(sys.modules, "chromadb", _build_fake_chroma_modules())
    import core.rag.datasource.vdb.chroma.chroma_vector as module

    # Reload so the module binds against the fake even if previously imported.
    return importlib.reload(module)
def test_chroma_config_to_params_builds_expected_payload(chroma_module):
    """to_chroma_params forwards connection fields and wraps auth in Settings."""
    config = chroma_module.ChromaConfig(
        host="localhost",
        port=8000,
        tenant="tenant-1",
        database="db-1",
        auth_provider="provider",
        auth_credentials="credentials",
    )
    params = config.to_chroma_params()
    assert params["host"] == "localhost"
    assert params["port"] == 8000
    assert params["tenant"] == "tenant-1"
    assert params["database"] == "db-1"
    assert params["ssl"] is False
    assert params["settings"].chroma_client_auth_provider == "provider"
    assert params["settings"].chroma_client_auth_credentials == "credentials"
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch):
    """create_collection runs under the redis lock and records the cache flag."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(chroma_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(chroma_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(chroma_module.redis_client, "set", MagicMock())
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    vector.create_collection("collection_1")
    vector._client.get_or_create_collection.assert_called_once_with("collection_1")
    chroma_module.redis_client.set.assert_called_once()
def test_create_with_empty_texts_is_noop(chroma_module):
    """create with no documents must not touch the client."""
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    vector.create([], [])
    vector._client.get_or_create_collection.assert_not_called()
def test_create_with_texts_creates_collection_and_upserts(chroma_module):
    """create with documents ensures the collection exists and upserts rows."""
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    docs = [Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})]
    vector.create(docs, [[0.1, 0.2]])
    vector._client.get_or_create_collection.assert_called()
    vector._client.collection.upsert.assert_called_once()
def test_delete_methods_and_text_exists(chroma_module):
    """delete_by_ids/metadata, text_exists, and delete all hit the fake client."""
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    vector.delete_by_ids([])
    vector._client.collection.delete.assert_not_called()
    vector.delete_by_ids(["id-1"])
    vector._client.collection.delete.assert_called_with(ids=["id-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}})
    vector._client.collection.get.return_value = {"ids": ["id-1"]}
    assert vector.text_exists("id-1") is True
    vector._client.collection.get.return_value = {}
    assert vector.text_exists("id-2") is False
    vector.delete()
    vector._client.delete_collection.assert_called_once_with("collection_1")
def test_search_by_vector_handles_empty_results(chroma_module):
vector = chroma_module.ChromaVector(
collection_name="collection_1",
config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
)
vector._client.collection.query.return_value = {"ids": [], "documents": [], "metadatas": [], "distances": []}
assert vector.search_by_vector([0.1, 0.2], top_k=2) == []
def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module):
    """Distances become scores and results below the threshold are dropped."""
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    # Two hits: distance 0.1 (score 0.9) and distance 0.8 (score 0.2).
    vector._client.collection.query.return_value = {
        "ids": [["id-1", "id-2"]],
        "documents": [["doc high", "doc low"]],
        "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]],
        "distances": [[0.1, 0.8]],
    }
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    # Only the low-distance hit clears the 0.5 threshold.
    assert len(docs) == 1
    assert docs[0].page_content == "doc high"
    assert docs[0].metadata["score"] == 0.9
def test_search_by_full_text_returns_empty_list(chroma_module):
    """search_by_full_text is unsupported for Chroma and always returns []."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    vector = chroma_module.ChromaVector(collection_name="collection_1", config=config)
    assert vector.search_by_full_text("query") == []
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch):
    """Factory reuses the dataset's stored collection name or generates one."""
    factory = chroma_module.ChromaVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(chroma_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost")
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None)
    with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are lower-cased by the factory.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one written back.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,927 @@
import importlib
import queue
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_clickzetta_module():
clickzetta = types.ModuleType("clickzetta")
class _FakeCursor:
def __init__(self):
self.execute = MagicMock()
self.executemany = MagicMock()
self.fetchall = MagicMock(return_value=[])
self.fetchone = MagicMock(return_value=(0,))
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class _FakeConnection:
def __init__(self):
self.cursor_obj = _FakeCursor()
def cursor(self):
return self.cursor_obj
def close(self):
return None
def connect(**_kwargs):
return _FakeConnection()
clickzetta.connect = connect
return clickzetta
@pytest.fixture
def clickzetta_module(monkeypatch):
    """Reload the clickzetta vector module against the fake ``clickzetta`` package."""
    monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
    import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module

    # Reload so the module binds to the faked dependency for this test.
    return importlib.reload(module)
def _config(module):
    """Return a fully-populated ClickzettaConfig built from *module*."""
    settings = {
        "username": "username",
        "password": "password",
        "instance": "instance",
        "service": "service",
        "workspace": "workspace",
        "vcluster": "cluster",
        "schema_name": "dify",
    }
    return module.ClickzettaConfig(**settings)
@pytest.mark.parametrize(
    ("field", "error_message"),
    [
        ("username", "CLICKZETTA_USERNAME"),
        ("password", "CLICKZETTA_PASSWORD"),
        ("instance", "CLICKZETTA_INSTANCE"),
        ("service", "CLICKZETTA_SERVICE"),
        ("workspace", "CLICKZETTA_WORKSPACE"),
        ("vcluster", "CLICKZETTA_VCLUSTER"),
        ("schema_name", "CLICKZETTA_SCHEMA"),
    ],
)
def test_clickzetta_config_validation(clickzetta_module, field, error_message):
    """Blanking any required field raises a ValueError naming its env var."""
    values = _config(clickzetta_module).model_dump()
    values[field] = ""
    with pytest.raises(ValueError, match=error_message):
        clickzetta_module.ClickzettaConfig.model_validate(values)
def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module):
    """_parse_metadata accepts JSON, double-encoded JSON, and falls back on junk."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    parsed = vector._parse_metadata('{"document_id":"doc-1"}', "row-1")
    assert parsed["doc_id"] == "row-1"
    assert parsed["document_id"] == "doc-1"
    # JSON string wrapping a JSON object (double-encoded) is unwrapped.
    parsed_double = vector._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2")
    assert parsed_double["doc_id"] == "row-2"
    assert parsed_double["document_id"] == "doc-2"
    # Unparseable input falls back to the row id for both identifiers.
    parsed_fallback = vector._parse_metadata("not-json", "row-3")
    assert parsed_fallback["doc_id"] == "row-3"
    assert parsed_fallback["document_id"] == "row-3"
def test_safe_doc_id_and_vector_format_helpers(clickzetta_module):
    """_format_vector_simple joins floats; _safe_doc_id sanitizes and truncates."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    assert vector._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3"
    assert vector._safe_doc_id("abc-123_DEF") == "abc-123_DEF"
    # Unsafe characters (spaces, punctuation, newlines) are stripped.
    assert vector._safe_doc_id("ab c;\n") == "abc"
    # Ids are capped at 255 characters.
    assert len(vector._safe_doc_id("a" * 300)) == 255
def test_table_exists_returns_false_for_not_found_and_other_exceptions(clickzetta_module):
    """_table_exists returns False for both 'not found' and unexpected errors."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"

    # Connection whose cursor raises the engine's 'table not found' error.
    @contextmanager
    def _ctx_not_found():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("CZLH-42000 table or view not found")
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx_not_found
    assert vector._table_exists() is False

    # Connection whose cursor raises an unrelated error.
    @contextmanager
    def _ctx_other_error():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("permission denied")
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx_other_error
    assert vector._table_exists() is False
def test_text_exists_handles_missing_table_and_existing_rows(clickzetta_module):
    """text_exists is False without a table and True when a matching row counts."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=False)
    assert vector.text_exists("doc-1") is False
    vector._table_exists = MagicMock(return_value=True)

    # Cursor reports a count of 1 for the queried id.
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.fetchone.return_value = (1,)
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    assert vector.text_exists("doc-1") is True
def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module):
    """Deletes skip the write queue for empty ids or a missing table."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()
    # Empty id list: nothing to delete.
    vector.delete_by_ids([])
    vector._execute_write.assert_not_called()
    # Missing table: both delete paths are no-ops.
    vector._table_exists = MagicMock(return_value=False)
    vector.delete_by_ids(["doc-1"])
    vector._execute_write.assert_not_called()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._execute_write.assert_not_called()
def test_search_short_circuit_behaviors(clickzetta_module):
    """Searches return [] when the table is missing or full-text is disabled."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=False)
    assert vector.search_by_vector([0.1, 0.2], top_k=2) == []
    # With the inverted index disabled, full-text search is skipped entirely.
    vector._config.enable_inverted_index = False
    assert vector.search_by_full_text("query", top_k=2) == []
def test_search_by_like_returns_documents_with_default_score(clickzetta_module):
    """LIKE-based fallback search wraps rows into Documents with score 0.5."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)
    vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"})

    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')]
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "content"
    # LIKE matches carry a fixed default relevance score.
    assert docs[0].metadata["score"] == 0.5
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
    """Factory builds a ClickzettaVector with a lower-cased collection name."""
    factory = clickzetta_module.ClickzettaVectorFactory()
    dataset = SimpleNamespace(id="dataset-1")
    monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION")
    # Pin every config knob the factory reads from dify_config.
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10)
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True)
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance")
    with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "collection"
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
    """get_instance is a singleton; config keys embed the connection identity."""
    # Reset the singleton so this test controls its creation.
    clickzetta_module.ClickzettaConnectionPool._instance = None
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool_1 = clickzetta_module.ClickzettaConnectionPool.get_instance()
    pool_2 = clickzetta_module.ClickzettaConnectionPool.get_instance()
    key = pool_1._get_config_key(_config(clickzetta_module))
    assert pool_1 is pool_2
    assert "username:instance:service:workspace:cluster:dify" in key
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch):
    """_create_connection retries after a failure and configures the result."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    config = _config(clickzetta_module)
    connection = MagicMock()
    # Avoid real backoff sleeps between retries.
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    # First connect attempt fails, second succeeds.
    monkeypatch.setattr(
        clickzetta_module.clickzetta, "connect", MagicMock(side_effect=[RuntimeError("boom"), connection])
    )
    pool._configure_connection = MagicMock()
    created = pool._create_connection(config)
    assert created is connection
    assert clickzetta_module.clickzetta.connect.call_count == 2
    pool._configure_connection.assert_called_once_with(connection)
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch):
    """_create_connection re-raises once every connect attempt has failed."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    config = _config(clickzetta_module)
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    monkeypatch.setattr(clickzetta_module.clickzetta, "connect", MagicMock(side_effect=RuntimeError("boom")))
    with pytest.raises(RuntimeError, match="boom"):
        pool._create_connection(config)
def test_connection_pool_configure_and_validate_connection(clickzetta_module):
    """_configure_connection runs setup SQL; _is_connection_valid probes the cursor.

    Uses ``MonkeyPatch.context`` instead of a manual ``MonkeyPatch()`` +
    ``undo()`` pair so the class-level patch is always reverted, even when an
    assertion fails mid-test (the original leaked the patch in that case).
    """
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
        pool = clickzetta_module.ClickzettaConnectionPool()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection = MagicMock()
        connection.cursor.return_value = cursor
        pool._configure_connection(connection)
        # Configuration issues at least two SQL statements.
        assert cursor.execute.call_count >= 2
        assert pool._is_connection_valid(connection) is True
        # A connection whose cursor() raises is reported invalid.
        bad_connection = MagicMock()
        bad_connection.cursor.side_effect = RuntimeError("bad connection")
        assert pool._is_connection_valid(bad_connection) is False
def test_connection_pool_configure_connection_swallows_errors(clickzetta_module):
    """_configure_connection must not raise when the cursor cannot be created.

    Uses ``MonkeyPatch.context`` instead of a manual ``MonkeyPatch()`` +
    ``undo()`` pair so the class-level patch is reverted even if the call
    under test unexpectedly raises (the original leaked the patch then).
    """
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
        pool = clickzetta_module.ClickzettaConnectionPool()
        connection = MagicMock()
        connection.cursor.side_effect = RuntimeError("cannot configure")
        # Should swallow the error rather than propagate it.
        pool._configure_connection(connection)
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch):
    """Exercise get/reuse/expiry, return_connection, cleanup and shutdown."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    config = _config(clickzetta_module)
    key = pool._get_config_key(config)
    # Empty pool: a fresh connection is created.
    created_connection = MagicMock()
    pool._create_connection = MagicMock(return_value=created_connection)
    first = pool.get_connection(config)
    assert first is created_connection
    # A valid pooled connection is reused.
    reusable_connection = MagicMock()
    pool._pools[key] = [(reusable_connection, clickzetta_module.time.time())]
    pool._is_connection_valid = MagicMock(return_value=True)
    reused = pool.get_connection(config)
    assert reused is reusable_connection
    # An expired/invalid pooled connection is closed and replaced.
    expired_connection = MagicMock()
    pool._pools[key] = [(expired_connection, 0.0)]
    pool._is_connection_valid = MagicMock(return_value=False)
    monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0))
    pool.get_connection(config)
    expired_connection.close.assert_called_once()
    # A valid returned connection goes back into the pool.
    random_connection = MagicMock()
    pool._is_connection_valid = MagicMock(return_value=True)
    pool.return_connection(config, random_connection)
    assert len(pool._pools[key]) == 1
    # Cleanup drops entries older than the timeout.
    pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)]
    pool._connection_timeout = 10
    pool._cleanup_expired_connections()
    assert len(pool._pools[key]) == 1
    # Returning to an unknown pool key closes the connection instead.
    unknown_pool = MagicMock()
    pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), unknown_pool)
    unknown_pool.close.assert_called_once()
    pool.shutdown()
    assert pool._shutdown is True
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch):
    """_start_cleanup_thread spawns a daemon that runs the cleanup loop."""
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._shutdown = False
    # First cleanup call flips _shutdown so the loop exits after one pass.
    pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True))
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)

    # Fake Thread that runs the target synchronously on start().
    class _Thread:
        def __init__(self, target, daemon):
            self._target = target
            self.daemon = daemon
            self.started = False

        def start(self):
            self.started = True
            self._target()

    monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread)
    pool._start_cleanup_thread()
    assert pool._cleanup_thread.started is True
    pool._cleanup_expired_connections.assert_called_once()
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch):
    """__init__ normalizes the table name and connection helpers hit the pool."""
    pool = MagicMock()
    pool.get_connection.return_value = "conn"
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool))
    monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock())
    vector = clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module))
    # Collection name is lower-cased with dashes mapped to underscores.
    assert vector._table_name == "my_collection"
    assert vector._get_connection() == "conn"
    vector._return_connection("conn")
    pool.return_connection.assert_called_with(vector._config, "conn")
    # The context manager borrows and then returns the pooled connection.
    with vector.get_connection_context() as conn:
        assert conn == "conn"
    assert pool.return_connection.call_count >= 2
    assert vector.get_type() == "clickzetta"
    assert vector._ensure_connection() == "conn"
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch):
    """Cover write-queue init idempotence, the worker loop, and _execute_write."""

    # Fake Thread that only records start() calls; the worker is run manually.
    class _Thread:
        def __init__(self, target, daemon):
            self.target = target
            self.daemon = daemon
            self.started = 0

        def start(self):
            self.started += 1

    monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread)
    # Reset class-level write-queue state before exercising init.
    clickzetta_module.ClickzettaVector._write_queue = None
    clickzetta_module.ClickzettaVector._write_thread = None
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._init_write_queue()
    clickzetta_module.ClickzettaVector._init_write_queue()
    # Second init must not spawn another worker thread.
    assert clickzetta_module.ClickzettaVector._write_thread.started == 1
    result_queue_ok = queue.Queue()
    result_queue_fail = queue.Queue()
    clickzetta_module.ClickzettaVector._write_queue = queue.Queue()
    clickzetta_module.ClickzettaVector._shutdown = False
    # One succeeding task, one raising task, then the None sentinel to stop.
    clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, result_queue_ok))
    clickzetta_module.ClickzettaVector._write_queue.put(
        (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, result_queue_fail)
    )
    clickzetta_module.ClickzettaVector._write_queue.put(None)
    clickzetta_module.ClickzettaVector._write_worker()
    assert result_queue_ok.get() == (True, 2)
    failed = result_queue_fail.get()
    assert failed[0] is False
    assert isinstance(failed[1], RuntimeError)
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    # _execute_write requires an initialized queue.
    clickzetta_module.ClickzettaVector._write_queue = None
    with pytest.raises(RuntimeError, match="Write queue not initialized"):
        vector._execute_write(lambda: None)

    # Queue stub that executes the task inline and reports success.
    class _ImmediateSuccessQueue:
        def put(self, task):
            func, args, kwargs, result_q = task
            result_q.put((True, func(*args, **kwargs)))

    clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue()
    assert vector._execute_write(lambda x: x * 2, 3) == 6

    # Queue stub that reports a failure; _execute_write must re-raise it.
    class _ImmediateFailQueue:
        def put(self, task):
            _, _, _, result_q = task
            result_q.put((False, ValueError("write failed")))

    clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue()
    with pytest.raises(ValueError, match="write failed"):
        vector._execute_write(lambda: None)
def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module):
    """_table_exists is True on success; create() schedules DDL then add_texts."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"

    # Connection whose cursor succeeds, so the existence probe passes.
    @contextmanager
    def _ctx_exists():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx_exists
    assert vector._table_exists() is True
    vector._execute_write = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="content", metadata={"doc_id": "d1"})]
    vector.create(docs, [[0.1, 0.2]])
    vector._execute_write.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_table_and_indexes_paths(clickzetta_module):
    """_create_table_and_indexes skips existing tables and honors index flags."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._create_vector_index = MagicMock()
    vector._create_inverted_index = MagicMock()
    # Existing table: no DDL issued.
    vector._table_exists = MagicMock(return_value=True)
    vector._create_table_and_indexes([[0.1, 0.2]])
    vector._create_vector_index.assert_not_called()
    vector._table_exists = MagicMock(return_value=False)

    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    # Fresh table with inverted index enabled: both indexes created.
    vector._create_table_and_indexes([[0.1, 0.2, 0.3]])
    vector._create_vector_index.assert_called_once()
    vector._create_inverted_index.assert_called_once()
    # Inverted index disabled: only the vector index is created.
    vector._config.enable_inverted_index = False
    vector._create_vector_index.reset_mock()
    vector._create_inverted_index.reset_mock()
    vector._create_table_and_indexes([])
    vector._create_vector_index.assert_called_once()
    vector._create_inverted_index.assert_not_called()
def test_create_vector_index_branches(clickzetta_module):
    """Cover the existing-index, probe-failure and error branches of _create_vector_index."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()
    # Index already listed: only the SHOW statement runs.
    cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")]
    vector._create_vector_index(cursor)
    assert cursor.execute.call_count == 1
    cursor.reset_mock()
    # SHOW fails: falls through to issuing the CREATE statement.
    cursor.execute.side_effect = [RuntimeError("show index failed"), None]
    vector._create_vector_index(cursor)
    assert cursor.execute.call_count == 2
    cursor.reset_mock()
    # CREATE reporting 'already exists' is tolerated.
    cursor.execute.side_effect = [None, RuntimeError("already exists")]
    cursor.fetchall.return_value = []
    vector._create_vector_index(cursor)
    cursor.reset_mock()
    # Any other CREATE failure propagates.
    cursor.execute.side_effect = [None, RuntimeError("unexpected")]
    cursor.fetchall.return_value = []
    with pytest.raises(RuntimeError, match="unexpected"):
        vector._create_vector_index(cursor)
def test_create_inverted_index_branches(clickzetta_module):
    """Cover existing-index, probe-failure and tolerated-error branches."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()
    # Index already listed: only the SHOW statement runs.
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    vector._create_inverted_index(cursor)
    assert cursor.execute.call_count == 1
    cursor.reset_mock()
    # SHOW fails: falls through to issuing the CREATE statement.
    cursor.execute.side_effect = [RuntimeError("show failed"), None]
    vector._create_inverted_index(cursor)
    assert cursor.execute.call_count == 2
    cursor.reset_mock()
    # CREATE reporting an existing index is tolerated.
    cursor.execute.side_effect = [
        None,
        RuntimeError("already has index"),
        None,
    ]
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    vector._create_inverted_index(cursor)
    cursor.reset_mock()
    # Other CREATE failures are swallowed (no raise expected here).
    cursor.execute.side_effect = [None, RuntimeError("other create failure")]
    cursor.fetchall.return_value = []
    vector._create_inverted_index(cursor)
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch):
    """add_texts batches writes; _insert_batch handles bad metadata and failures."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.batch_size = 2
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()
    vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id))
    docs = [
        Document(page_content="doc-1", metadata={"doc_id": "id-1"}),
        Document(page_content="doc-2", metadata={"doc_id": "id-2"}),
        Document(page_content="doc-3", metadata={"doc_id": "id-3"}),
    ]
    vectors = [[0.1], [0.2], [0.3]]
    # Empty input: no write scheduled.
    vector.add_texts([], [])
    vector._execute_write.assert_not_called()
    added_ids = vector.add_texts(docs, vectors)
    assert added_ids == ["id-1", "id-2", "id-3"]
    # batch_size=2 splits three docs into two write tasks.
    assert vector._execute_write.call_count == 2
    assert vector._execute_write.call_args_list[0].args == (
        vector._insert_batch,
        docs[:2],
        vectors[:2],
        ["id-1", "id-2"],
        0,
        2,
        2,
    )
    assert vector._execute_write.call_args_list[1].args == (
        vector._insert_batch,
        docs[2:],
        vectors[2:],
        ["id-3"],
        2,
        2,
        2,
    )
    # Degenerate batches (empty docs, mismatched lengths) must not raise.
    vector._insert_batch([], [], [], 0, 2, 1)
    vector._insert_batch(docs[:1], vectors, ["id-1"], 0, 2, 1)
    # A doc with non-JSON-serializable metadata ({1} is a set) alongside a good one.
    bad_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}})
    good_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"})

    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    vector._insert_batch(
        [bad_doc, good_doc],
        [[0.1, 0.2], [0.3, 0.4]],
        ["id-bad", "id-good"],
        0,
        2,
        1,
    )

    # Connection whose executemany fails: the error must propagate.
    @contextmanager
    def _ctx_error():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.executemany.side_effect = RuntimeError("insert failed")
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx_error
    with pytest.raises(RuntimeError, match="insert failed"):
        vector._insert_batch([good_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1)
    # Restore the real _safe_doc_id: empty/invalid ids fall back to uuid4().
    monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id")
    vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector)
    assert vector._safe_doc_id("") == "generated-id"
    assert vector._safe_doc_id("!!!") == "generated-id"
def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module):
    """Public deletes queue the matching _impl functions; impls run without error."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()
    vector._table_exists = MagicMock(return_value=True)
    vector.delete_by_ids(["id-1", "id-2"])
    vector._execute_write.assert_called_once()
    assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl
    vector._execute_write.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._execute_write.assert_called_once()
    assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl
    vector._safe_doc_id = MagicMock(side_effect=lambda x: x)

    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    # Exercise the impl bodies directly against the mocked connection.
    vector._delete_by_ids_impl(["id-1", "id-2"])
    vector._delete_by_metadata_field_impl("document_id", "doc-1")
def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module):
    """Distance-to-score mapping: cosine -> 1 - d; l2 -> 1 / (1 + d)."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.vector_distance_function = "cosine_distance"
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)
    vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"})

    # Single row with distance 0.2.
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)]
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    cosine_docs = vector.search_by_vector(
        [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"}
    )
    assert cosine_docs[0].metadata["score"] == pytest.approx(0.9)
    vector._config.vector_distance_function = "l2_distance"
    l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5)
    assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2)
def test_search_by_full_text_success_and_fallback(clickzetta_module):
    """Full-text search parses rows; on SQL failure it falls back to LIKE search."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)

    # Rows with double-encoded and invalid metadata JSON.
    @contextmanager
    def _ctx_success():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.fetchall.return_value = [
            ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'),
            ("seg-2", "content-2", "invalid-json"),
        ]
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx_success
    # Query contains a quote to exercise escaping.
    docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1})
    assert len(docs) == 2
    assert docs[0].metadata["score"] == 1.0
    assert docs[1].metadata["doc_id"] == "seg-2"

    # Full-text SQL failing triggers the LIKE fallback.
    @contextmanager
    def _ctx_failure():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("full text failed")
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx_failure
    vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})])
    fallback_docs = vector.search_by_full_text("query", top_k=1)
    assert fallback_docs == vector._search_by_like.return_value
def test_search_by_like_missing_table_and_delete_table(clickzetta_module):
    """_search_by_like returns [] without a table; delete() runs the DROP path."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=False)
    assert vector._search_by_like("query", top_k=1) == []

    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    # Must not raise against the mocked connection.
    vector.delete()
def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module):
    """Edge paths: invalid return closes connection; cleanup tolerates missing locks."""
    # Build the pool without __init__ so no cleanup thread starts.
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._pools = {}
    pool._pool_locks = {}
    pool._max_pool_size = 1
    pool._connection_timeout = 10
    pool._lock = clickzetta_module.threading.Lock()
    pool._shutdown = False
    config = _config(clickzetta_module)
    key = pool._get_config_key(config)
    pool._pools[key] = [(MagicMock(), 1.0)]
    pool._pool_locks[key] = clickzetta_module.threading.Lock()
    # Returning an invalid connection closes it instead of pooling it.
    pool._is_connection_valid = MagicMock(return_value=False)
    conn = MagicMock()
    pool.return_connection(config, conn)
    conn.close.assert_called_once()
    # A pool entry with no matching lock must not break cleanup.
    pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)]
    pool._cleanup_expired_connections()
    pool.shutdown()
    assert pool._shutdown is True
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch):
    """The cleanup loop survives an exception raised by a cleanup pass."""
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._shutdown = False

    # Flip _shutdown then raise, so the loop must both exit and swallow the error.
    def _cleanup_then_fail():
        pool._shutdown = True
        raise RuntimeError("cleanup failed")

    pool._cleanup_expired_connections = MagicMock(side_effect=_cleanup_then_fail)
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)

    # Fake Thread that runs the target synchronously on start().
    class _Thread:
        def __init__(self, target, daemon):
            self._target = target
            self.daemon = daemon

        def start(self):
            self._target()

    monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread)
    pool._start_cleanup_thread()
    pool._cleanup_expired_connections.assert_called_once()
def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module):
    """_parse_metadata falls back for non-dict/None; worker exits on queue errors."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    # JSON that parses to a non-dict falls back to the row id.
    parsed_non_dict = vector._parse_metadata("[1,2,3]", "row-1")
    assert parsed_non_dict["doc_id"] == "row-1"
    assert parsed_non_dict["document_id"] == "row-1"
    parsed_none = vector._parse_metadata(None, "row-2")
    assert parsed_none["doc_id"] == "row-2"
    assert parsed_none["document_id"] == "row-2"
    # Worker with no queue returns immediately.
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue = None
    clickzetta_module.ClickzettaVector._write_worker()

    # Queue whose get() raises: the worker must not propagate the error.
    class _BadQueue:
        def get(self, timeout):
            clickzetta_module.ClickzettaVector._shutdown = True
            raise RuntimeError("queue failed")

    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue = _BadQueue()
    clickzetta_module.ClickzettaVector._write_worker()
def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module):
    """Inverted-index creation tolerates pre-existing indexes; insert accepts odd metadata."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()
    # Catalog already reports an inverted index on page_content.
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    # Second execute simulates the server's "index already exists" error;
    # _create_inverted_index is expected to treat it as benign.
    cursor.execute.side_effect = [
        None,
        RuntimeError("already has index with the same type cannot create inverted index"),
        None,
    ]
    vector._create_inverted_index(cursor)
    vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value))
    @contextmanager
    def _ctx():
        # Fake connection context whose cursor is itself a context manager.
        connection = MagicMock()
        cursor_obj = MagicMock()
        cursor_obj.__enter__.return_value = cursor_obj
        cursor_obj.__exit__.return_value = None
        connection.cursor.return_value = cursor_obj
        yield connection
    vector.get_connection_context = _ctx
    # Non-dict metadata on the document must not break batch insertion.
    vector._insert_batch(
        [SimpleNamespace(page_content="content", metadata="not-a-dict")],
        [[0.1, 0.2]],
        ["doc-1"],
        0,
        1,
        1,
    )
def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module):
    """Full-text search returns [] without a table and tolerates malformed metadata rows."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.enable_inverted_index = True
    vector._table_name = "table_1"
    # Missing table: short-circuit to an empty result set.
    vector._table_exists = MagicMock(return_value=False)
    assert vector.search_by_full_text("query") == []
    vector._table_exists = MagicMock(return_value=True)
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        # One row with non-dict JSON metadata, one with NULL metadata.
        cursor.fetchall.return_value = [
            ("seg-1", "content-1", "[1,2,3]"),
            ("seg-2", "content-2", None),
        ]
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx
    docs = vector.search_by_full_text("query")
    assert len(docs) == 2
    # doc_id falls back to the row id for both malformed metadata shapes.
    assert docs[0].metadata["doc_id"] == "seg-1"
    assert docs[1].metadata["doc_id"] == "seg-2"

View File

@ -0,0 +1,364 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_couchbase_modules():
couchbase = types.ModuleType("couchbase")
couchbase_auth = types.ModuleType("couchbase.auth")
couchbase_cluster = types.ModuleType("couchbase.cluster")
couchbase_management = types.ModuleType("couchbase.management")
couchbase_management_search = types.ModuleType("couchbase.management.search")
couchbase_options = types.ModuleType("couchbase.options")
couchbase_vector = types.ModuleType("couchbase.vector_search")
couchbase_search = types.ModuleType("couchbase.search")
class PasswordAuthenticator:
def __init__(self, user, password):
self.user = user
self.password = password
class ClusterOptions:
def __init__(self, auth):
self.auth = auth
class SearchOptions:
def __init__(self, **kwargs):
self.kwargs = kwargs
class VectorQuery:
def __init__(self, field, vector, top_k):
self.field = field
self.vector = vector
self.top_k = top_k
class VectorSearch:
@staticmethod
def from_vector_query(vector_query):
return {"vector_query": vector_query}
class QueryStringQuery:
def __init__(self, query):
self.query = query
class SearchRequest:
@staticmethod
def create(payload):
return {"payload": payload}
class SearchIndex:
def __init__(self, name, params, source_name):
self.name = name
self.params = params
self.source_name = source_name
class _QueryResult:
def __init__(self, rows=None):
self._rows = rows or []
def execute(self):
return self
def __iter__(self):
return iter(self._rows)
class _SearchIter:
def __init__(self, rows=None):
self._rows = rows or []
def rows(self):
return self._rows
class _Collection:
def __init__(self):
self.upsert = MagicMock(return_value=True)
class _SearchIndexManager:
def __init__(self):
self.upsert_index = MagicMock()
class _Scope:
def __init__(self):
self._collection = _Collection()
self._search_index_manager = _SearchIndexManager()
self.search = MagicMock(return_value=_SearchIter())
def collection(self, _name):
return self._collection
def search_indexes(self):
return self._search_index_manager
class _CollectionManager:
def __init__(self):
self.create_collection = MagicMock()
self.drop_collection = MagicMock()
self.get_all_scopes = MagicMock(return_value=[])
class _Bucket:
def __init__(self):
self._scope = _Scope()
self._collections = _CollectionManager()
def scope(self, _scope_name):
return self._scope
def collections(self):
return self._collections
class Cluster:
def __init__(self, connection_string, options):
self.connection_string = connection_string
self.options = options
self._bucket = _Bucket()
self.wait_until_ready = MagicMock()
self.query = MagicMock(return_value=_QueryResult())
def bucket(self, _name):
return self._bucket
couchbase_auth.PasswordAuthenticator = PasswordAuthenticator
couchbase_cluster.Cluster = Cluster
couchbase_management_search.SearchIndex = SearchIndex
couchbase_options.ClusterOptions = ClusterOptions
couchbase_options.SearchOptions = SearchOptions
couchbase_vector.VectorQuery = VectorQuery
couchbase_vector.VectorSearch = VectorSearch
couchbase_search.QueryStringQuery = QueryStringQuery
couchbase_search.SearchRequest = SearchRequest
couchbase.search = couchbase_search
couchbase.management = couchbase_management
return {
"couchbase": couchbase,
"couchbase.auth": couchbase_auth,
"couchbase.cluster": couchbase_cluster,
"couchbase.management": couchbase_management,
"couchbase.management.search": couchbase_management_search,
"couchbase.options": couchbase_options,
"couchbase.vector_search": couchbase_vector,
"couchbase.search": couchbase_search,
}
@pytest.fixture
def couchbase_module(monkeypatch):
    """Install the fake couchbase SDK and return the freshly reloaded vector module."""
    fake_modules = _build_fake_couchbase_modules()
    for module_name in fake_modules:
        monkeypatch.setitem(sys.modules, module_name, fake_modules[module_name])
    import core.rag.datasource.vdb.couchbase.couchbase_vector as module
    return importlib.reload(module)
def _config(module):
    """Build a minimally valid CouchbaseConfig from the reloaded module."""
    params = {
        "connection_string": "couchbase://localhost",
        "user": "user",
        "password": "pass",
        "bucket_name": "bucket",
        "scope_name": "scope",
    }
    return module.CouchbaseConfig(**params)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("connection_string", "", "CONNECTION_STRING is required"),
        ("user", "", "COUCHBASE_USER is required"),
        ("password", "", "COUCHBASE_PASSWORD is required"),
        # NOTE(review): the validator currently raises the PASSWORD message for
        # a missing bucket_name as well — presumably a copy-paste in the source
        # validator; confirm against the production code before "fixing" this.
        ("bucket_name", "", "COUCHBASE_PASSWORD is required"),
        ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"),
    ],
)
def test_couchbase_config_validation(couchbase_module, field, value, message):
    """Each required config field raises a ValidationError when blank."""
    values = _config(couchbase_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        couchbase_module.CouchbaseConfig.model_validate(values)
def test_init_sets_cluster_handles(couchbase_module):
    """Constructing the vector wires bucket/scope names and waits for readiness."""
    cb_vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    assert cb_vector._scope_name == "scope"
    assert cb_vector._bucket_name == "bucket"
    cb_vector._cluster.wait_until_ready.assert_called_once()
def test_create_and_create_collection_branches(couchbase_module, monkeypatch):
    """Exercise create() plus every guard branch inside _create_collection."""
    vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector)
    vector._collection_name = "collection_1"
    vector._client_config = _config(couchbase_module)
    vector._scope_name = "scope"
    vector._bucket_name = "bucket"
    vector._bucket = MagicMock()
    vector._scope = MagicMock()
    vector._collection_exists = MagicMock(return_value=False)
    vector.add_texts = MagicMock()
    monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c")
    vector._create_collection = MagicMock()
    docs = [Document(page_content="text", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1, 0.2]])
    # "a-b-c" arrives as "abc": create() evidently strips hyphens from uuid4().
    vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock())
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    # Branch 1: redis cache says the collection was already created -> no-op.
    monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(vector_length=2, uuid="uuid-1")
    vector._bucket.collections().create_collection.assert_not_called()
    # Branch 2: cache miss but collection exists server-side -> still no create.
    monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._collection_exists = MagicMock(return_value=True)
    vector._create_collection(vector_length=2, uuid="uuid-2")
    vector._bucket.collections().create_collection.assert_not_called()
    # Branch 3: genuinely missing -> create the collection plus a search index.
    vector._collection_exists = MagicMock(return_value=False)
    vector._create_collection(vector_length=3, uuid="uuid-3")
    vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1")
    vector._scope.search_indexes().upsert_index.assert_called_once()
    search_index = vector._scope.search_indexes().upsert_index.call_args.args[0]
    assert search_index.name == "collection_1_search"
    # Embedding dims in the index mapping must follow vector_length.
    assert (
        search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"]
        == 3
    )
    couchbase_module.redis_client.set.assert_called_once()
def test_collection_exists_get_type_and_add_texts(couchbase_module):
    """Cover _collection_exists scope scanning, add_texts upserts, and get_type."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="collection_1")])
    vector._bucket.collections().get_all_scopes.return_value = [scope_obj]
    assert vector._collection_exists("collection_1") is True
    # Same scope but a different collection name -> not found.
    scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="other")])
    vector._bucket.collections().get_all_scopes.return_value = [scope_obj]
    assert vector._collection_exists("collection_1") is False
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    # One upsert per document against the scope's collection stub.
    assert vector._scope.collection("collection_1").upsert.call_count == 2
    assert vector.get_type() == couchbase_module.VectorType.COUCHBASE
def test_query_delete_helpers(couchbase_module):
    """text_exists interprets query rows; delete helpers issue queries and swallow errors."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    # A non-empty result set means the document exists.
    vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([{"count": 2}]))
    assert vector.text_exists("id-1") is True
    vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([]))
    assert vector.text_exists("id-2") is False
    query_result = MagicMock()
    query_result.execute.return_value = None
    vector._cluster.query.return_value = query_result
    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_document_id("id-1")
    vector.delete_by_metadata_field("document_id", "doc-1")
    # Each delete helper issues at least one cluster query.
    assert vector._cluster.query.call_count >= 3
    # delete_by_ids is best-effort: a query failure must not propagate.
    vector._cluster.query.side_effect = RuntimeError("delete failed")
    vector.delete_by_ids(["id-3"])
def test_search_methods_and_format_metadata(couchbase_module):
    """Vector/full-text search result shaping, error wrapping, and metadata key cleanup."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    row_1 = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9)
    row_2 = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3)
    vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_1, row_2])
    # score_threshold=0.5 filters out the 0.3 hit.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].page_content == "doc-a"
    assert docs[0].metadata["document_id"] == "d-1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # SDK failures are wrapped into a ValueError by the vector layer.
    vector._scope.search.side_effect = RuntimeError("search error")
    with pytest.raises(ValueError, match="Search failed"):
        vector.search_by_vector([0.1], top_k=1)
    vector._scope.search.side_effect = None
    row_3 = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7)
    vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_3])
    docs = vector.search_by_full_text("hello", top_k=1)
    assert len(docs) == 1
    assert docs[0].metadata["doc_id"] == "x"
    vector._scope.search.side_effect = RuntimeError("full text failed")
    with pytest.raises(ValueError, match="Search failed"):
        vector.search_by_full_text("hello", top_k=1)
    # "metadata." prefixes are stripped; other keys pass through untouched.
    assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2}
def test_delete_collection_and_factory(couchbase_module, monkeypatch):
    """delete() drops the matching collection; the factory picks the index name per dataset."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    scopes = [
        SimpleNamespace(collections=[SimpleNamespace(name="other")]),
        SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]),
    ]
    vector._bucket.collections().get_all_scopes.return_value = scopes
    vector.delete()
    # NOTE(review): drop targets the "_default" scope here — confirm that is
    # intended in delete() rather than the configured scope name.
    vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1")
    factory = couchbase_module.CouchbaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(
        couchbase_module,
        "current_app",
        SimpleNamespace(
            config={
                "COUCHBASE_CONNECTION_STRING": "couchbase://localhost",
                "COUCHBASE_USER": "user",
                "COUCHBASE_PASSWORD": "pass",
                "COUCHBASE_BUCKET_NAME": "bucket",
                "COUCHBASE_SCOPE_NAME": "scope",
            }
        ),
    )
    with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Existing index struct wins; otherwise a collection name is generated.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,121 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
def _build_fake_elasticsearch_modules():
elasticsearch = types.ModuleType("elasticsearch")
class ConnectionError(Exception):
pass
class Elasticsearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.ping = MagicMock(return_value=True)
self.info = MagicMock(return_value={"version": {"number": "8.12.0"}})
self.indices = SimpleNamespace(
refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock()
)
elasticsearch.Elasticsearch = Elasticsearch
elasticsearch.ConnectionError = ConnectionError
return {"elasticsearch": elasticsearch}
@pytest.fixture
def elasticsearch_ja_module(monkeypatch):
    """Install the fake elasticsearch SDK, then reload base and JA vector modules."""
    fake_modules = _build_fake_elasticsearch_modules()
    for module_name in fake_modules:
        monkeypatch.setitem(sys.modules, module_name, fake_modules[module_name])
    import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module
    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module
    importlib.reload(base_module)
    return importlib.reload(ja_module)
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
    """A redis cache hit means the index already exists: no ES call, no cache write."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock))
    # Cache hit for the "collection created" marker.
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock())
    vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector)
    vector._collection_name = "test"
    vector._client = MagicMock()
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    elasticsearch_ja_module.redis_client.set.assert_not_called()
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch):
    """Cache miss either creates the index (dims from embeddings) or skips if present."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock())
    vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector)
    vector._collection_name = "test"
    vector._client = MagicMock()
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2, 0.3]], [{}])
    vector._client.indices.create.assert_called_once()
    kwargs = vector._client.indices.create.call_args.kwargs
    assert kwargs["index"] == "test"
    # Vector field dims must match the embedding length (3 here).
    assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3
    elasticsearch_ja_module.redis_client.set.assert_called_once()
    vector._client.indices.create.reset_mock()
    elasticsearch_ja_module.redis_client.set.reset_mock()
    # Index already exists server-side: skip creation but still mark the cache.
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    elasticsearch_ja_module.redis_client.set.assert_called_once()
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch):
    """Factory reuses the dataset's stored index name or generates a new one."""
    factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(
        elasticsearch_ja_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_HOST": "localhost",
                "ELASTICSEARCH_PORT": 9200,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
            }
        ),
    )
    with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION"
    # The factory persists a freshly generated index struct on the dataset.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,405 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_elasticsearch_modules():
elasticsearch = types.ModuleType("elasticsearch")
class ConnectionError(Exception):
pass
class Elasticsearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.ping = MagicMock(return_value=True)
self.info = MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}})
self.index = MagicMock()
self.exists = MagicMock(return_value=False)
self.delete = MagicMock()
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.indices = SimpleNamespace(
refresh=MagicMock(),
delete=MagicMock(),
exists=MagicMock(return_value=False),
create=MagicMock(),
)
elasticsearch.Elasticsearch = Elasticsearch
elasticsearch.ConnectionError = ConnectionError
return {"elasticsearch": elasticsearch}
@pytest.fixture
def elasticsearch_module(monkeypatch):
    """Install the fake elasticsearch SDK and return the reloaded vector module."""
    fake_modules = _build_fake_elasticsearch_modules()
    for module_name in fake_modules:
        monkeypatch.setitem(sys.modules, module_name, fake_modules[module_name])
    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module
    return importlib.reload(module)
def _regular_config(module, **overrides):
    """Validated self-hosted ElasticSearchConfig; ``overrides`` replace defaults."""
    defaults = {
        "host": "localhost",
        "port": 9200,
        "username": "elastic",
        "password": "secret",
        "verify_certs": False,
        "request_timeout": 10,
        "retry_on_timeout": True,
        "max_retries": 3,
    }
    return module.ElasticSearchConfig.model_validate({**defaults, **overrides})
def _cloud_config(module, **overrides):
    """Validated cloud-mode ElasticSearchConfig; ``overrides`` replace defaults."""
    defaults = {
        "use_cloud": True,
        "cloud_url": "https://cloud.example:9243",
        "api_key": "api-key",
        "verify_certs": True,
        "ca_certs": "/tmp/ca.pem",
        "request_timeout": 10,
        "retry_on_timeout": True,
        "max_retries": 3,
    }
    return module.ElasticSearchConfig.model_validate({**defaults, **overrides})
@pytest.mark.parametrize(
    ("values", "message"),
    [
        # Cloud mode requires both the endpoint URL and an API key.
        ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"),
        ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"),
        # Self-hosted mode requires host, port, and full credentials.
        ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"),
        ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"),
        ({"host": "h", "port": 9200, "username": None, "password": "p"}, "USERNAME is required"),
        ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"),
    ],
)
def test_elasticsearch_config_validation(elasticsearch_module, values, message):
    """Each incomplete config variant fails validation with its specific message."""
    with pytest.raises(ValidationError, match=message):
        elasticsearch_module.ElasticSearchConfig.model_validate(values)
def test_init_client_cloud_configuration(elasticsearch_module):
    """Cloud config maps to hosts=[cloud_url], api_key, and TLS settings."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    client = MagicMock()
    client.ping.return_value = True
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls:
        result = vector._init_client(_cloud_config(elasticsearch_module))
    assert result is client
    kwargs = es_cls.call_args.kwargs
    assert kwargs["hosts"] == ["https://cloud.example:9243"]
    assert kwargs["api_key"] == "api-key"
    assert kwargs["verify_certs"] is True
    assert kwargs["ca_certs"] == "/tmp/ca.pem"
def test_init_client_regular_https_and_http_fallback(elasticsearch_module):
    """Hosts keep an explicit https:// scheme; bare hostnames default to http."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    client = MagicMock()
    client.ping.return_value = True
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls:
        vector._init_client(
            _regular_config(
                elasticsearch_module,
                host="https://es.example",
                port=9443,
                verify_certs=True,
                ca_certs="/tmp/ca.pem",
            )
        )
    kwargs = es_cls.call_args.kwargs
    assert kwargs["hosts"] == ["https://es.example:9443"]
    assert kwargs["verify_certs"] is True
    assert kwargs["ca_certs"] == "/tmp/ca.pem"
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls:
        vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200))
    kwargs = es_cls.call_args.kwargs
    assert kwargs["hosts"] == ["http://es.internal:9200"]
    # Plain-http connections pass no TLS options at all.
    assert "verify_certs" not in kwargs
def test_init_client_connection_failures(elasticsearch_module):
    """Failed ping, SDK connection error, and unexpected errors all surface as ConnectionError."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    client = MagicMock()
    # Ping returning False means the cluster is unreachable.
    client.ping.return_value = False
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client):
        with pytest.raises(ConnectionError, match="Failed to connect"):
            vector._init_client(_regular_config(elasticsearch_module))
    # SDK-level connection error is re-raised with a specific message.
    with patch.object(
        elasticsearch_module,
        "Elasticsearch",
        side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"),
    ):
        with pytest.raises(ConnectionError, match="Vector database connection error"):
            vector._init_client(_regular_config(elasticsearch_module))
    # Any other exception is wrapped as an initialization failure.
    with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")):
        with pytest.raises(ConnectionError, match="initialization failed"):
            vector._init_client(_regular_config(elasticsearch_module))
def test_init_get_version_and_check_version(elasticsearch_module):
    """__init__ wires client/version checks; _get_version strips -SNAPSHOT; _check_version gates on 8.x."""
    with (
        patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client,
        patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version,
        patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version,
    ):
        vector = elasticsearch_module.ElasticSearchVector(
            "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"]
        )
        init_client.assert_called_once()
        get_version.assert_called_once()
        check_version.assert_called_once()
        assert vector._attributes == ["doc_id"]
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._client = MagicMock()
    vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}}
    # The pre-release suffix is dropped from the reported version.
    assert vector._get_version() == "8.13.2"
    vector._version = "7.17.0"
    with pytest.raises(ValueError, match="greater than 8.0.0"):
        vector._check_version()
    # 8.0.0 itself is accepted.
    vector._version = "8.0.0"
    vector._check_version()
def test_crud_methods_and_get_type(elasticsearch_module):
    """add_texts indexes per-doc then refreshes; delete helpers and get_type behave as expected."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock())
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    # One index call per document, then a single refresh of the collection.
    assert vector._client.index.call_count == 2
    vector._client.indices.refresh.assert_called_once_with(index="collection_1")
    vector._client.exists.return_value = True
    assert vector.text_exists("id-1") is True
    # Empty id list short-circuits without any delete calls.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector.delete_by_ids(["id-1", "id-2"])
    assert vector._client.delete.call_count == 2
    # delete_by_metadata_field searches for ids and delegates to delete_by_ids.
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}}
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "d1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
    vector.delete_by_ids.reset_mock()
    # No matching hits: nothing to delete.
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_metadata_field("doc_id", "d2")
    vector.delete_by_ids.assert_not_called()
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection_1")
    assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH
def test_search_by_vector_and_full_text(elasticsearch_module):
    """Vector search applies the score threshold and doc filter; full-text builds a bool query."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.8,
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "doc-a",
                        elasticsearch_module.Field.VECTOR: [0.1],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"},
                    },
                },
                {
                    "_score": 0.2,
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "doc-b",
                        elasticsearch_module.Field.VECTOR: [0.2],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"},
                    },
                },
            ]
        }
    }
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    # Only the 0.8 hit survives the 0.5 threshold.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    knn = vector._client.search.call_args.kwargs["knn"]
    assert knn["k"] == 2
    # NOTE(review): num_candidates comes out as 3 for top_k=2 — the exact
    # formula is not visible here; this pins the observed value.
    assert knn["num_candidates"] == 3
    # The document id filter must be attached to the knn clause.
    assert "filter" in knn
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "text-hit",
                        elasticsearch_module.Field.VECTOR: [0.3],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"},
                    }
                }
            ]
        }
    }
    docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"])
    assert len(docs) == 1
    assert docs[0].page_content == "text-hit"
    # With a document filter, the text match is wrapped in a bool query.
    query = vector._client.search.call_args.kwargs["query"]
    assert "bool" in query
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
    """create() delegates to create_collection/add_texts; create_collection covers all guard paths."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock())
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock())
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="a", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    # Fresh instance with the real create_collection for the branch checks below.
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock())
    # Branch 1: redis cache hit -> nothing to create.
    monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    # Branch 2: cache miss + index absent -> index created with matching dims.
    monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_called_once()
    mappings = vector._client.indices.create.call_args.kwargs["mappings"]
    assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2
    elasticsearch_module.redis_client.set.assert_called_once()
    vector._client.indices.create.reset_mock()
    elasticsearch_module.redis_client.set.reset_mock()
    # Branch 3: cache miss + index present -> skip creation, still mark cache.
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    elasticsearch_module.redis_client.set.assert_called_once()
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch):
    """Factory builds self-hosted or cloud configs and falls back when cloud_url is missing."""
    factory = elasticsearch_module.ElasticSearchVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Branch 1: explicit self-hosted configuration.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": False,
                "ELASTICSEARCH_HOST": "es-host",
                "ELASTICSEARCH_PORT": 9200,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
                "ELASTICSEARCH_VERIFY_CERTS": False,
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    cfg = vector_cls.call_args.kwargs["config"]
    assert cfg.use_cloud is False
    assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION"
    # Branch 2: cloud configuration with an endpoint URL and API key.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": True,
                "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic",
                "ELASTICSEARCH_API_KEY": "api-key",
                "ELASTICSEARCH_VERIFY_CERTS": True,
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_2 == "vector"
    cfg = vector_cls.call_args.kwargs["config"]
    assert cfg.use_cloud is True
    assert cfg.cloud_url == "https://cloud.elastic"
    assert dataset_without_index.index_struct is not None
    # Branch 3: cloud requested without a URL -> fall back to self-hosted settings.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": True,
                "ELASTICSEARCH_CLOUD_URL": None,
                "ELASTICSEARCH_HOST": "fallback-host",
                "ELASTICSEARCH_PORT": 9201,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    cfg = vector_cls.call_args.kwargs["config"]
    assert cfg.use_cloud is False
    assert cfg.host == "fallback-host"

View File

@ -0,0 +1,371 @@
import importlib
import json
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_hologres_modules():
holo_module = types.ModuleType("holo_search_sdk")
holo_types_module = types.ModuleType("holo_search_sdk.types")
holo_types_module.BaseQuantizationType = str
holo_types_module.DistanceType = str
holo_types_module.TokenizerType = str
def _connect(**kwargs):
client = MagicMock()
client.kwargs = kwargs
client.connect = MagicMock()
client.check_table_exist = MagicMock(return_value=False)
client.open_table = MagicMock(return_value=MagicMock())
client.execute = MagicMock(return_value=[])
client.drop_table = MagicMock()
return client
holo_module.connect = MagicMock(side_effect=_connect)
return {
"holo_search_sdk": holo_module,
"holo_search_sdk.types": holo_types_module,
}
@pytest.fixture
def hologres_module(monkeypatch):
    """Import the Hologres vector module against the fake SDK stubs.

    The fakes are injected into ``sys.modules`` before importing so the
    module-level ``import holo_search_sdk`` resolves to them; the reload
    ensures the stubs are used even if the module was imported earlier.
    """
    for name, module in _build_fake_hologres_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.hologres.hologres_vector as module

    return importlib.reload(module)
def _valid_config(module):
    """Return a fully populated, valid ``HologresVectorConfig`` for the given module."""
    settings = {
        "host": "localhost",
        "port": 80,
        "database": "dify",
        "access_key_id": "ak",
        "access_key_secret": "sk",
        "schema_name": "public",
        "tokenizer": "jieba",
        "distance_method": "Cosine",
        "base_quantization_type": "rabitq",
        "max_degree": 64,
        "ef_construction": 400,
    }
    return module.HologresVectorConfig(**settings)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config HOLOGRES_HOST is required"),
        ("database", "", "config HOLOGRES_DATABASE is required"),
        ("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"),
        ("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"),
    ],
)
def test_hologres_config_validation(hologres_module, field, value, message):
    """Blanking each required connection field raises a field-specific validation error."""
    # Start from a known-good config and invalidate exactly one field.
    values = _valid_config(hologres_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        hologres_module.HologresVectorConfig.model_validate(values)
def test_init_client_and_get_type(hologres_module):
    """Construction connects via the SDK with config values and normalizes the table name."""
    vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module))
    # The config's schema_name is forwarded to the SDK as the ``schema`` kwarg.
    hologres_module.holo.connect.assert_called_once_with(
        host="localhost",
        port=80,
        database="dify",
        access_key_id="ak",
        access_key_secret="sk",
        schema="public",
    )
    vector._client.connect.assert_called_once()
    # Collection name is lower-cased and prefixed to form the table name.
    assert vector.table_name == "embedding_collection_one"
    assert vector.get_type() == hologres_module.VectorType.HOLOGRES
def test_create_delegates_collection_creation_and_upsert(hologres_module):
    """create() sizes the collection from the embedding width and defers to add_texts."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    result = vector.create(docs, [[0.1, 0.2]])
    assert result is None
    # Dimension (2) comes from the length of the first embedding vector.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_returns_empty_for_empty_documents(hologres_module):
    """An empty batch yields no ids and never opens the table."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    result = vector.add_texts([], [])
    assert result == []
    vector._client.open_table.assert_not_called()
def test_add_texts_batches_and_serializes_metadata(hologres_module):
    """Large batches upsert in chunks of 100 with metadata serialized to JSON."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    table = vector._client.open_table.return_value
    documents = [
        Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"})
        for i in range(100)
    ]
    # Document with metadata=None exercises the fallback: empty id and "{}" meta.
    documents.append(SimpleNamespace(page_content="doc-100", metadata=None))
    embeddings = [[float(i)] for i in range(len(documents))]
    ids = vector.add_texts(documents, embeddings)
    assert ids[:2] == ["id-0", "id-1"]
    assert ids[-1] == ""
    assert len(ids) == 101
    # 101 documents -> two batches (100 + 1), hence two open_table/upsert calls.
    assert vector._client.open_table.call_count == 2
    assert table.upsert_multi.call_count == 2
    first_call = table.upsert_multi.call_args_list[0].kwargs
    second_call = table.upsert_multi.call_args_list[1].kwargs
    assert first_call["index_column"] == "id"
    assert first_call["column_names"] == ["id", "text", "meta", "embedding"]
    assert first_call["update_columns"] == ["text", "meta", "embedding"]
    assert len(first_call["values"]) == 100
    assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"}
    assert second_call["values"][0][0] == ""
    assert second_call["values"][0][2] == "{}"
def test_text_exists_handles_missing_and_present_tables(hologres_module):
    """text_exists is False when the table is absent and only queries when present."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # First existence probe: table missing; second probe: table present.
    vector._client.check_table_exist.side_effect = [False, True]
    vector._client.execute.return_value = [(1,)]
    assert vector.text_exists("seg-1") is False
    assert vector.text_exists("seg-1") is True
    # Only the second call (table present) should have issued SQL.
    vector._client.execute.assert_called_once()
def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module):
    """Matching rows yield a list of ids; an empty result set yields ``None``."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # First query returns two rows, second returns nothing.
    vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []]
    populated = vector.get_ids_by_metadata_field("document_id", "doc-1")
    empty = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert populated == ["id-1", "id-2"]
    assert empty is None
def test_delete_by_ids_branches(hologres_module):
    """delete_by_ids: no-op for [], no SQL when table missing, one DELETE otherwise."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # Empty id list: returns before even probing table existence.
    vector.delete_by_ids([])
    vector._client.check_table_exist.assert_not_called()
    # Table absent: nothing to delete, no SQL issued.
    vector._client.check_table_exist.return_value = False
    vector.delete_by_ids(["id-1"])
    vector._client.execute.assert_not_called()
    # Table present: a single DELETE statement covers all ids.
    vector._client.check_table_exist.return_value = True
    vector.delete_by_ids(["id-1", "id-2"])
    vector._client.execute.assert_called_once()
def test_delete_by_metadata_field_branches(hologres_module):
    """delete_by_metadata_field only issues SQL when the backing table exists."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = False
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.execute.assert_not_called()
    vector._client.check_table_exist.return_value = True
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.execute.assert_called_once()
def test_search_by_vector_returns_empty_when_table_missing(hologres_module):
    """Vector search short-circuits to [] when the backing table does not exist."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = False
    results = vector.search_by_vector([0.1, 0.2])
    assert results == []
def test_search_by_vector_applies_filter_and_processes_results(hologres_module):
    """Search applies the document-id filter, parses metadata, and scores results."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = True
    table = vector._client.open_table.return_value
    # Fluent query mock: select/limit/where all return the same query object.
    query = MagicMock()
    table.search_vector.return_value = query
    query.select.return_value = query
    query.limit.return_value = query
    query.where.return_value = query
    # One row carries JSON-string metadata, the other an already-parsed dict.
    query.fetchall.return_value = [
        (0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'),
        (0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}),
    ]
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["doc-1"],
    )
    # Observed scoring: distance 0.2 -> score 0.8; the 0.9-distance row falls
    # below the 0.5 threshold and is dropped.
    assert len(docs) == 1
    assert docs[0].page_content == "doc-1"
    assert docs[0].metadata["doc_id"] == "seg-1"
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    table.search_vector.assert_called_once()
    # The document_ids_filter must translate into exactly one where() clause.
    query.where.assert_called_once()
def test_search_by_full_text_returns_empty_when_table_missing(hologres_module):
    """Full-text search short-circuits to [] when the backing table does not exist."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = False
    results = vector.search_by_full_text("query")
    assert results == []
def test_search_by_full_text_applies_filter_and_processes_results(hologres_module):
    """Full-text search passes the document filter through and keeps the rank as score."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = True
    table = vector._client.open_table.return_value
    # Fluent query mock: limit/where return the same query object.
    search_query = MagicMock()
    table.search_text.return_value = search_query
    search_query.limit.return_value = search_query
    search_query.where.return_value = search_query
    # Rows mix JSON-string and dict metadata; last column is the text rank.
    search_query.fetchall.return_value = [
        ("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95),
        ("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7),
    ]
    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"])
    # Unlike vector search, no threshold applies here: both rows come back.
    assert len(docs) == 2
    assert docs[0].metadata["doc_id"] == "seg-1"
    assert docs[0].metadata["score"] == pytest.approx(0.95)
    assert docs[1].metadata["score"] == pytest.approx(0.7)
    table.search_text.assert_called_once()
    search_query.where.assert_called_once()
def test_delete_handles_existing_and_missing_tables(hologres_module):
    """delete() drops the table only when it exists."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # First call sees no table, second call sees one.
    vector._client.check_table_exist.side_effect = [False, True]
    vector.delete()
    vector._client.drop_table.assert_not_called()
    vector.delete()
    vector._client.drop_table.assert_called_once_with(vector.table_name)
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch):
    """A redis cache hit means the collection already exists: no checks, no writes."""
    # The redis lock is entered/exited but must not block the early return.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = False
    monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
    # Non-None cache value signals "collection already created".
    monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._create_collection(3)
    vector._client.check_table_exist.assert_not_called()
    hologres_module.redis_client.set.assert_not_called()
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch):
    """On a cache miss, the table is created, polled for readiness, and indexed."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = False
    monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
    # Avoid real waiting during the readiness-polling loop.
    monkeypatch.setattr(hologres_module.time, "sleep", MagicMock())
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # Pre-create check False, then polling: not-ready once, ready on retry.
    vector._client.check_table_exist.side_effect = [False, False, True]
    table = vector._client.open_table.return_value
    vector._create_collection(3)
    # One CREATE TABLE statement.
    vector._client.execute.assert_called_once()
    # Vector index is configured from the config's HNSW/quantization settings.
    table.set_vector_index.assert_called_once_with(
        column="embedding",
        distance_method="Cosine",
        base_quantization_type="rabitq",
        max_degree=64,
        ef_construction=400,
    )
    table.create_text_index.assert_called_once_with(
        index_name="ft_idx_collection_one",
        column="text",
        tokenizer="jieba",
    )
    # Success is memoized in redis so later calls can return early.
    hologres_module.redis_client.set.assert_called_once()
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch):
    """If the table never appears within the polling window, a RuntimeError is raised."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = False
    monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(hologres_module.time, "sleep", MagicMock())
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # One pre-create existence check plus 15 polling attempts, all negative.
    vector._client.check_table_exist.side_effect = [False] + [False] * 15
    with pytest.raises(RuntimeError, match="was not ready after 30s"):
        vector._create_collection(3)
    # The cache marker must not be written on failure.
    hologres_module.redis_client.set.assert_not_called()
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch):
    """The factory reuses a dataset's stored collection name or generates a new one."""
    factory = hologres_module.HologresVectorFactory()
    # Dataset that already recorded a collection in its index struct.
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}},
        index_struct=None,
    )
    # Dataset with no index struct yet: the factory must generate and persist one.
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80)
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64)
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_EF_CONSTRUCTION", 400)
    with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection"
    # The generated config is assembled from dify_config values patched above.
    generated_config = vector_cls.call_args_list[1].kwargs["config"]
    assert generated_config.host == "127.0.0.1"
    assert generated_config.database == "dify"
    assert generated_config.access_key_id == "ak"
    # The factory persists the generated collection into the dataset's index struct.
    assert json.loads(dataset_without_index.index_struct) == {
        "type": hologres_module.VectorType.HOLOGRES,
        "vector_store": {"class_prefix": "generated_collection"},
    }

View File

@ -0,0 +1,243 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_elasticsearch_modules():
elasticsearch = types.ModuleType("elasticsearch")
class Elasticsearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.index = MagicMock()
self.exists = MagicMock(return_value=False)
self.delete = MagicMock()
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.indices = SimpleNamespace(
refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock()
)
elasticsearch.Elasticsearch = Elasticsearch
return {"elasticsearch": elasticsearch}
@pytest.fixture
def huawei_module(monkeypatch):
    """Import the Huawei Cloud vector module against the stub elasticsearch client.

    Stubs are injected into ``sys.modules`` before the import; the reload makes
    sure the module binds to the stubs even if it was imported earlier.
    """
    for name, module in _build_fake_elasticsearch_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module

    return importlib.reload(module)
def _config(module):
    """Return a baseline ``HuaweiCloudVectorConfig`` with basic-auth credentials."""
    return module.HuaweiCloudVectorConfig(
        hosts="http://localhost:9200",
        username="user",
        password="pass",
    )
def test_create_ssl_context(huawei_module):
    """The helper builds a context with hostname checks and cert verification off."""
    context = huawei_module.create_ssl_context()
    assert context.verify_mode == huawei_module.ssl.CERT_NONE
    assert context.check_hostname is False
def test_huawei_config_validation_and_params(huawei_module):
    """hosts is mandatory; credentials toggle basic_auth in the ES client params."""
    with pytest.raises(ValidationError, match="HOSTS is required"):
        huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""})
    config = _config(huawei_module)
    params = config.to_elasticsearch_params()
    assert params["hosts"] == ["http://localhost:9200"]
    assert params["basic_auth"] == ("user", "pass")
    # Without credentials, basic_auth must be omitted entirely.
    config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None)
    params = config.to_elasticsearch_params()
    assert "basic_auth" not in params
def test_init_get_type_and_add_texts(huawei_module):
    """Construction lower-cases the index name; add_texts indexes docs then refreshes."""
    vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module))
    assert vector._collection_name == "collection"
    assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    # One index() call per document, then a single refresh of the index.
    assert vector._client.index.call_count == 2
    vector._client.indices.refresh.assert_called_once_with(index="collection")
def test_crud_methods(huawei_module):
    """Exercise text_exists, delete_by_ids, delete_by_metadata_field, and delete."""
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    vector._client.exists.return_value = True
    assert vector.text_exists("id-1") is True
    # Empty id list is a no-op.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector.delete_by_ids(["id-1"])
    vector._client.delete.assert_called_once_with(index="collection", id="id-1")
    # delete_by_metadata_field searches for matching ids, then delegates to delete_by_ids.
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}}
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "x")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
    vector.delete_by_ids.reset_mock()
    # No hits: no delegation.
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_metadata_field("doc_id", "x")
    vector.delete_by_ids.assert_not_called()
    # delete() removes the whole index.
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection")
def test_search_by_vector_and_full_text(huawei_module):
    """Vector search applies the score threshold; full-text search maps hits to Documents."""
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    # Two hits: 0.9 clears the 0.5 threshold, 0.1 does not.
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.9,
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "doc-a",
                        huawei_module.Field.VECTOR: [0.1],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "1"},
                    },
                },
                {
                    "_score": 0.1,
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "doc-b",
                        huawei_module.Field.VECTOR: [0.2],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "2"},
                    },
                },
            ]
        }
    }
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # top_k must be forwarded into the kNN query body.
    query_body = vector._client.search.call_args.kwargs["body"]
    assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "text-hit",
                        huawei_module.Field.VECTOR: [0.3],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "3"},
                    }
                }
            ]
        }
    }
    docs = vector.search_by_full_text("hello", top_k=3)
    assert len(docs) == 1
    assert docs[0].page_content == "text-hit"
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch):
    """Hits whose Document ends up with metadata=None are excluded from results."""

    class FakeDocument:
        """Document stand-in that deliberately discards metadata (always None)."""

        def __init__(self, page_content, vector, metadata):
            self.page_content = page_content
            self.vector = vector
            # Intentionally ignore the ``metadata`` argument to simulate a
            # document without metadata reaching the score-attachment code.
            self.metadata = None

    monkeypatch.setattr(huawei_module, "Document", FakeDocument)
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.9,
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "doc-a",
                        huawei_module.Field.VECTOR: [0.1],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "1"},
                    },
                }
            ]
        }
    }
    docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5)
    # The only hit passes the threshold but has no metadata, so it is dropped.
    assert docs == []
def test_create_and_create_collection_paths(huawei_module, monkeypatch):
    """create() delegates; create_collection honors the redis cache and builds mappings."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock())
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="a", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    # Fresh instance: a redis cache hit must skip index creation entirely.
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    # Cache miss + missing index: the index is created with a vector mapping.
    monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_called_once()
    kwargs = vector._client.indices.create.call_args.kwargs
    mappings = kwargs["mappings"]
    # Vector dimension (2) is inferred from the first embedding.
    assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2
    assert kwargs["settings"] == {"index.vector": True}
    huawei_module.redis_client.set.assert_called_once()
def test_huawei_factory_branches(huawei_module, monkeypatch):
    """The factory lower-cases stored or generated collection names and persists new ones."""
    factory = huawei_module.HuaweiCloudVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200")
    monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user")
    monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass")
    with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Both stored and generated names are normalized to lowercase index names.
    assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection"
    # The dataset without an index struct gets one written back.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,412 @@
import importlib
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_iris_module():
iris = types.ModuleType("iris")
def connect(**_kwargs):
conn = MagicMock()
conn.cursor.return_value = MagicMock()
return conn
iris.connect = MagicMock(side_effect=connect)
return iris
@pytest.fixture
def iris_module(monkeypatch):
    """Import the IRIS vector module against the fake driver and reset its pool singleton."""
    monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
    import core.rag.datasource.vdb.iris.iris_vector as module

    reloaded = importlib.reload(module)
    # Reset the module-level singleton so every test starts without a cached pool.
    reloaded._pool_instance = None
    return reloaded
def _config(module, **overrides):
    """Build an ``IrisVectorConfig`` from sane defaults, applying per-test overrides."""
    defaults = {
        "IRIS_HOST": "localhost",
        "IRIS_SUPER_SERVER_PORT": 1972,
        "IRIS_USER": "user",
        "IRIS_PASSWORD": "pass",
        "IRIS_DATABASE": "db",
        "IRIS_SCHEMA": "schema",
        "IRIS_CONNECTION_URL": "url",
        "IRIS_MIN_CONNECTION": 1,
        "IRIS_MAX_CONNECTION": 2,
        "IRIS_TEXT_INDEX": True,
        "IRIS_TEXT_INDEX_LANGUAGE": "en",
    }
    return module.IrisVectorConfig.model_validate({**defaults, **overrides})
def test_get_iris_pool_singleton(iris_module):
    """get_iris_pool lazily creates one shared pool and reuses it on later calls."""
    iris_module._pool_instance = None
    cfg = _config(iris_module)
    with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls:
        pool_1 = iris_module.get_iris_pool(cfg)
        pool_2 = iris_module.get_iris_pool(cfg)
    assert pool_1 == "pool"
    assert pool_2 == "pool"
    # The pool constructor must run exactly once despite two lookups.
    pool_cls.assert_called_once_with(cfg)
@pytest.fixture
def pool_with_min_max(iris_module):
    """Pool configured min=2/max=3 with a patched _create_connection; yields (pool, mock)."""
    cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3)
    # Patch connection creation so initialization counts calls instead of connecting.
    with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn:
        pool = iris_module.IrisConnectionPool(cfg)
        yield pool, create_conn
def test_pool_initialization_respects_min_max(pool_with_min_max):
    """The pool eagerly creates exactly min_connection (2) connections on startup."""
    pool, create_connection_mock = pool_with_min_max
    assert create_connection_mock.call_count == 2
    assert len(pool._pool) == 2
@pytest.fixture
def pool_for_get_connection(iris_module):
    """A real pool (min=2/max=3) whose internals tests overwrite before checkout."""
    cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3)
    pool = iris_module.IrisConnectionPool(cfg)
    return pool
def test_get_connection_returns_existing_and_increments(pool_for_get_connection):
    """Checkout hands out a pooled connection and bumps the in-use counter."""
    pool = pool_for_get_connection
    pooled_conn = MagicMock()
    pool._in_use = 0
    pool._pool = [pooled_conn]
    checked_out = pool.get_connection()
    assert checked_out is pooled_conn
    assert pool._in_use == 1
def test_get_connection_creates_new_when_empty(pool_for_get_connection):
    """With the pool drained but capacity left, checkout creates a fresh connection."""
    pool = pool_for_get_connection
    pool._in_use = 0
    pool._pool = []
    pool._create_connection = MagicMock(return_value="new-conn")
    checked_out = pool.get_connection()
    assert checked_out == "new-conn"
def test_get_connection_raises_when_exhausted(pool_for_get_connection):
    """When every slot is in use and the pool is empty, checkout fails loudly."""
    pool = pool_for_get_connection
    pool._in_use = pool._max_size
    pool._pool = []
    with pytest.raises(RuntimeError, match="exhausted"):
        pool.get_connection()
@pytest.fixture
def pool_for_return_connection(iris_module):
    """A pool whose eager initialization is suppressed, so it starts empty."""
    cfg = _config(iris_module)
    with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
        pool = iris_module.IrisConnectionPool(cfg)
    return pool
def test_return_connection_adds_healthy(pool_for_return_connection):
    """A healthy connection is returned to the pool and the in-use count decrements."""
    pool = pool_for_return_connection
    pool._in_use = 1
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    pool.return_connection(conn)
    assert pool._pool[-1] is conn
    assert pool._in_use == 0
def test_return_connection_replaces_bad(pool_for_return_connection):
    """A connection failing its health probe is closed and replaced with a fresh one."""
    pool = pool_for_return_connection
    pool._in_use = 1
    bad_conn = MagicMock()
    bad_cursor = MagicMock()
    # The health-check execute raising marks the connection as broken.
    bad_cursor.execute.side_effect = OSError("bad")
    bad_conn.cursor.return_value = bad_cursor
    replacement = MagicMock()
    pool._create_connection = MagicMock(return_value=replacement)
    pool.return_connection(bad_conn)
    bad_conn.close.assert_called_once()
    assert pool._pool[-1] is replacement
    assert pool._in_use == 0
def test_return_connection_ignores_none(pool_for_return_connection):
    """Returning ``None`` is a no-op and must not grow the pool."""
    pool = pool_for_return_connection
    size_before = len(pool._pool)
    pool.return_connection(None)
    assert len(pool._pool) == size_before
@pytest.fixture
def pool_for_schema_and_close(iris_module):
    """An uninitialized pool pre-seeded with one mock connection/cursor pair.

    Returns ``(pool, conn, cursor)`` so tests can assert on SQL issued through
    the seeded cursor.
    """
    cfg = _config(iris_module)
    with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
        pool = iris_module.IrisConnectionPool(cfg)
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    pool._pool = [conn]
    return pool, conn, cursor
def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close):
    """A schema already recorded as initialized triggers no SQL at all."""
    pool, _conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = {"cached_schema"}
    pool.ensure_schema_exists("cached_schema")
    cursor.execute.assert_not_called()
def test_ensure_schema_exists_creates_new(pool_for_schema_and_close):
    """An unknown schema with no DB row triggers CREATE SCHEMA plus a commit."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    # fetchone -> (0,) means the existence query found no matching schema.
    cursor.fetchone.return_value = (0,)
    pool.ensure_schema_exists("new_schema")
    assert "new_schema" in pool._schemas_initialized
    assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list)
    conn.commit.assert_called_once()
def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close):
    """A schema already present in the database requires no commit."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    # fetchone -> (1,) means the existence query found the schema.
    cursor.fetchone.return_value = (1,)
    pool.ensure_schema_exists("existing_schema")
    conn.commit.assert_not_called()
def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close):
    """A failing SQL statement is re-raised after the transaction is rolled back."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    cursor.execute.side_effect = RuntimeError("schema failure")
    with pytest.raises(RuntimeError, match="schema failure"):
        pool.ensure_schema_exists("broken_schema")
    conn.rollback.assert_called()
def test_close_all_closes_and_resets(iris_module):
    """close_all closes every connection (tolerating close errors) and resets state."""
    cfg = _config(iris_module)
    with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
        pool = iris_module.IrisConnectionPool(cfg)
    conn = MagicMock()
    conn_2 = MagicMock()
    # A failing close() must not abort the shutdown of the remaining state.
    conn_2.close.side_effect = OSError("close fail")
    pool._pool = [conn, conn_2]
    pool._schemas_initialized = {"x"}
    pool.close_all()
    assert pool._pool == []
    assert pool._in_use == 0
    assert pool._schemas_initialized == set()
def test_iris_vector_init_get_cursor_and_create(iris_module):
    """Cover construction, the _get_cursor commit/rollback paths, and create()."""
    pool = MagicMock()
    pool.get_connection.return_value = MagicMock()
    with patch.object(iris_module, "get_iris_pool", return_value=pool):
        vector = iris_module.IrisVector("collection", _config(iris_module))
    # Table name is upper-cased with an "EMBEDDING_" prefix.
    assert vector.table_name == "EMBEDDING_COLLECTION"
    assert vector.schema == "schema"
    assert vector.get_type() == iris_module.VectorType.IRIS
    # Happy path: cursor is yielded, then the connection commits and returns to the pool.
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    vector.pool.get_connection.return_value = conn
    with vector._get_cursor() as got_cursor:
        assert got_cursor is cursor
    conn.commit.assert_called_once()
    vector.pool.return_connection.assert_called_with(conn)
    # Error path: the exception propagates and the connection rolls back.
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    vector.pool.get_connection.return_value = conn
    with pytest.raises(RuntimeError, match="boom"):
        with vector._get_cursor():
            raise RuntimeError("boom")
    conn.rollback.assert_called_once()
    # create() sizes the collection from the embedding width, then adds texts.
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["id-1"])
    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"]
    vector._create_collection.assert_called_once_with(2)
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
    """Exercise add_texts, text_exists, deletions, and vector search via one shared cursor."""
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", _config(iris_module))
    cursor = MagicMock()

    # Replace the real cursor context manager with one yielding the shared mock.
    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # Deterministic fallback id for the document whose metadata is None.
    monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id")
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        SimpleNamespace(page_content="b", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "generated-id"]
    assert cursor.execute.call_count == 2
    # text_exists maps row-present/absent to True/False and swallows DB errors.
    cursor.fetchone.return_value = (1,)
    assert vector.text_exists("id-1") is True
    cursor.fetchone.return_value = None
    assert vector.text_exists("id-2") is False
    vector._get_cursor = MagicMock(side_effect=RuntimeError("db down"))
    assert vector.text_exists("id-3") is False
    vector._get_cursor = _cursor_ctx
    # Empty id list is a no-op; a populated list issues exactly one DELETE.
    vector.delete_by_ids([])
    before = cursor.execute.call_count
    vector.delete_by_ids(["id-1", "id-2"])
    assert cursor.execute.call_count == before + 1
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert "meta LIKE" in cursor.execute.call_args.args[0]
    # Vector search: the short malformed row is skipped, the 0.2-score row
    # misses the 0.5 threshold, only the 0.9 row survives.
    cursor.fetchall.return_value = [
        ("id-1", "text-1", '{"document_id":"d-1"}', 0.9),
        ("id-2", "text-2", '{"document_id":"d-2"}', 0.2),
        ("id-x",),
    ]
    docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
    """Full-text search: ranked path, fallback after a rank failure, and LIKE mode."""
    cfg = _config(iris_module, IRIS_TEXT_INDEX=True)
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", cfg)
    cursor = MagicMock()

    # Shared mock cursor injected in place of the real context manager.
    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # Ranked path: rows with a numeric rank and with a None rank (score 0.0).
    cursor.execute.side_effect = None
    cursor.fetchall.return_value = [
        ("id-1", "text-1", '{"document_id":"d-1"}', 0.7),
        ("id-2", "text-2", "{}", None),
    ]
    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 2
    assert docs[0].metadata["score"] == pytest.approx(0.7)
    assert docs[1].metadata["score"] == pytest.approx(0.0)
    # Fallback path: the first (ranked) statement fails, a second query succeeds.
    cursor.reset_mock()
    cursor.execute.side_effect = [RuntimeError("rank failed"), None]
    cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)]
    docs = vector.search_by_full_text("query", top_k=1)
    assert len(docs) == 1
    assert cursor.execute.call_count == 2
    # LIKE path (no text index): wildcard characters in the query are escaped
    # via libs.helper.escape_like_pattern, stubbed here.
    cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False)
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector_like = iris_module.IrisVector("collection", cfg_like)
    vector_like._get_cursor = _cursor_ctx
    fake_libs = types.ModuleType("libs")
    fake_helper = types.ModuleType("libs.helper")
    fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%")
    monkeypatch.setitem(sys.modules, "libs", fake_libs)
    monkeypatch.setitem(sys.modules, "libs.helper", fake_helper)
    cursor.reset_mock()
    cursor.execute.side_effect = None
    cursor.fetchall.return_value = []
    assert vector_like.search_by_full_text("100%", top_k=1) == []
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch):
    """Cover table drop, cached vs. uncached collection creation, and the factory's collection naming."""
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True))
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    vector.delete()
    assert "DROP TABLE" in cursor.execute.call_args.args[0]
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(iris_module.redis_client, "set", MagicMock())
    # Redis cache hit: collection treated as already created, so no additional execute
    # beyond the delete() above.
    monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(2)
    cursor.execute.assert_called_once()
    cursor.reset_mock()
    # Cache miss with text index enabled: three statements run and the cache flag is set.
    monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None))
    vector.pool.ensure_schema_exists = MagicMock()
    vector._create_collection(3)
    assert cursor.execute.call_count == 3
    iris_module.redis_client.set.assert_called_once()
    cursor.reset_mock()
    # Without the text index, one fewer statement is executed.
    vector.config.IRIS_TEXT_INDEX = False
    vector._create_collection(3)
    assert cursor.execute.call_count == 2
    # Factory: reuse the class_prefix from index_struct_dict when present, otherwise
    # generate a name and persist an index struct on the dataset.
    factory = iris_module.IrisVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en")
    with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,394 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_opensearch_modules():
opensearchpy = types.ModuleType("opensearchpy")
opensearch_helpers = types.ModuleType("opensearchpy.helpers")
class BulkIndexError(Exception):
def __init__(self, errors):
super().__init__("bulk error")
self.errors = errors
class OpenSearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.indices = SimpleNamespace(
refresh=MagicMock(),
exists=MagicMock(return_value=False),
delete=MagicMock(),
create=MagicMock(),
)
self.bulk = MagicMock(return_value={"errors": False, "items": []})
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.delete_by_query = MagicMock()
self.get = MagicMock(return_value={"_id": "id"})
self.exists = MagicMock(return_value=True)
opensearch_helpers.BulkIndexError = BulkIndexError
opensearch_helpers.bulk = MagicMock()
opensearchpy.OpenSearch = OpenSearch
opensearchpy.helpers = opensearch_helpers
return {
"opensearchpy": opensearchpy,
"opensearchpy.helpers": opensearch_helpers,
}
@pytest.fixture
def lindorm_module(monkeypatch):
    """Reload the lindorm vector module with the fake opensearchpy modules installed."""
    for name, module in _build_fake_opensearch_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    # Reload so module-level opensearchpy references bind to the fakes.
    import core.rag.datasource.vdb.lindorm.lindorm_vector as module

    return importlib.reload(module)
def _config(module):
return module.LindormVectorStoreConfig(
hosts="http://localhost:9200",
username="user",
password="pass",
using_ugc=False,
request_timeout=3.0,
)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("hosts", None, "config URL is required"),
        ("username", None, "config USERNAME is required"),
        ("password", None, "config PASSWORD is required"),
    ],
)
def test_lindorm_config_validation(lindorm_module, field, value, message):
    """Each required config field raises a ValidationError with a field-specific message when missing."""
    values = _config(lindorm_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        lindorm_module.LindormVectorStoreConfig.model_validate(values)
def test_to_opensearch_params_and_init(lindorm_module):
    """Config maps to opensearch params; names are lowercased; UGC mode requires a routing value."""
    cfg = _config(lindorm_module)
    params = cfg.to_opensearch_params()
    assert params["hosts"] == "http://localhost:9200"
    assert params["http_auth"] == ("user", "pass")
    vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False)
    assert vector._collection_name == "collection"
    assert vector.get_type() == lindorm_module.VectorType.LINDORM
    # UGC mode without a routing_value is rejected.
    with pytest.raises(ValueError, match="routing_value"):
        lindorm_module.LindormVectorStore("c", cfg, using_ugc=True)
    vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE")
    assert vector_ugc._routing == "route"
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch):
    """create() delegates to collection creation + add_texts; add_texts batches bulk writes with routing."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}])
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    # Fresh store for the real add_texts path; sleep is stubbed to keep the test fast.
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock())
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
        Document(page_content="c", metadata={"doc_id": "id-3"}),
    ]
    embeddings = [[0.1], [0.2], [0.3]]
    # Three documents with batch_size=2 -> two bulk calls.
    vector.add_texts(docs, embeddings, batch_size=2, timeout=9)
    assert vector._client.bulk.call_count == 2
    actions = vector._client.bulk.call_args_list[0].args[0]
    # Bulk actions interleave an index header (carrying routing) with the document body.
    assert actions[0]["index"]["routing"] == "route"
    assert actions[1][lindorm_module.ROUTING_FIELD] == "route"
    vector.refresh()
    vector._client.indices.refresh.assert_called_once_with(index="collection")
def test_add_texts_error_paths(lindorm_module):
    """Bulk item errors and bulk transport failures both surface as retry exhaustion."""
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    # Response-level item error.
    vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]}
    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    with pytest.raises(Exception, match="RetryError"):
        vector.add_texts(docs, [[0.1]], batch_size=1)
    # Transport-level exception.
    vector._client.bulk.side_effect = RuntimeError("bulk failed")
    with pytest.raises(Exception, match="RetryError"):
        vector.add_texts(docs, [[0.1]], batch_size=1)
def test_metadata_lookup_and_delete_by_metadata(lindorm_module):
    """Metadata lookups add a routing-field filter; delete-by-metadata only deletes when ids match."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}}
    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert ids == ["id-1", "id-2"]
    query = vector._client.search.call_args.kwargs["body"]
    must_conditions = query["query"]["bool"]["must"]
    # UGC mode injects a routing-field term condition alongside the metadata filter.
    assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions)
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"])
    # No matching ids -> no delete call.
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_ids.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-2")
    vector.delete_by_ids.assert_not_called()
def test_delete_by_ids_paths(lindorm_module):
    """delete_by_ids: no-ops on empty input/missing index, skips absent docs, survives 404 bulk errors."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    # Empty id list short-circuits before touching the client.
    vector.delete_by_ids([])
    vector._client.indices.exists.assert_not_called()
    # Missing index -> nothing to do.
    vector._client.indices.exists.return_value = False
    vector.delete_by_ids(["id-1"])
    vector._client.indices.exists.return_value = True
    # Only the first id exists, so exactly one delete action is bulked, with routing.
    vector._client.exists.side_effect = [True, False]
    lindorm_module.helpers.bulk.reset_mock()
    vector.delete_by_ids(["id-1", "id-2"])
    lindorm_module.helpers.bulk.assert_called_once()
    actions = lindorm_module.helpers.bulk.call_args.args[1]
    assert len(actions) == 1
    assert actions[0]["routing"] == "route"
    lindorm_module.helpers.bulk.reset_mock()
    # BulkIndexError with 404 (already gone) and 500 items is handled without raising.
    lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError(
        errors=[
            {"delete": {"status": 404, "_id": "id-404"}},
            {"delete": {"status": 500, "_id": "id-500"}},
        ]
    )
    vector._client.exists.side_effect = [True]
    vector.delete_by_ids(["id-1"])
def test_delete_and_text_exists(lindorm_module):
    """delete() uses delete_by_query in UGC mode, index deletion otherwise; text_exists maps errors to False."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector.delete()
    vector._client.delete_by_query.assert_called_once()
    vector._client.indices.refresh.assert_called_once_with(index="collection")
    # Non-UGC mode drops the whole index, but only when it exists.
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    vector._client.indices.exists.return_value = True
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60})
    vector._client.indices.delete.reset_mock()
    vector._client.indices.exists.return_value = False
    vector.delete()
    vector._client.indices.delete.assert_not_called()
    assert vector.text_exists("id-1") is True
    # Any client error is treated as "document absent".
    vector._client.get.side_effect = RuntimeError("missing")
    assert vector.text_exists("id-1") is False
def test_search_by_vector_validation_and_success(lindorm_module):
    """search_by_vector validates its input, applies the score threshold, and propagates client errors."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    with pytest.raises(ValueError, match="should be a list"):
        vector.search_by_vector("bad")
    with pytest.raises(ValueError, match="should be floats"):
        vector.search_by_vector([0.1, "bad"])
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.9,
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-a",
                        lindorm_module.Field.VECTOR: [0.1],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"},
                    },
                },
                {
                    "_score": 0.2,
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-b",
                        lindorm_module.Field.VECTOR: [0.2],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"},
                    },
                },
            ]
        }
    }
    # Only the 0.9-score hit clears the 0.5 threshold.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    call_kwargs = vector._client.search.call_args.kwargs
    query = call_kwargs["body"]
    assert "ext" in query
    # The knn query carries the document-id filter; routing is passed as a request param.
    assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"]
    assert call_kwargs["params"]["routing"] == "route"
    vector._client.search.side_effect = RuntimeError("search failed")
    with pytest.raises(RuntimeError, match="search failed"):
        vector.search_by_vector([0.1])
def test_search_by_full_text_success_and_error(lindorm_module):
    """Full-text search builds documents from hits, applies the id filter, and propagates client errors."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-a",
                        lindorm_module.Field.VECTOR: [0.1],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "1"},
                    }
                }
            ]
        }
    }
    docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "doc-a"
    query = vector._client.search.call_args.kwargs["body"]
    # The document-id filter must end up in the bool filter clause.
    assert query["query"]["bool"]["filter"]
    vector._client.search.side_effect = RuntimeError("full text failed")
    with pytest.raises(RuntimeError, match="full text failed"):
        vector.search_by_full_text("hello")
def test_create_collection_paths(lindorm_module, monkeypatch):
    """create_collection validates input, honors the redis creation cache, and builds index mappings."""
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    with pytest.raises(ValueError, match="cannot be empty"):
        vector.create_collection([])
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock())
    # Cache hit -> creation skipped entirely.
    monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()
    # Cache miss + missing index -> index created with the requested knn method params.
    monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"})
    vector._client.indices.create.assert_called_once()
    body = vector._client.indices.create.call_args.kwargs["body"]
    assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf"
    assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine"
    vector._client.indices.create.reset_mock()
    # Index already present -> no create call.
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()
def test_lindorm_factory_branches(lindorm_module, monkeypatch):
    """Factory derives collection names/routing for plain vs. UGC mode, existing vs. new datasets."""
    factory = lindorm_module.LindormVectorStoreFactory()
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0)
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2")
    monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={})
    embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3])
    # The UGC flag must be explicitly configured.
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None)
    with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"):
        factory.init_vector(dataset, attributes=[], embeddings=embeddings)
    # Existing dataset, plain mode: reuse the stored class_prefix (lowercased).
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False)
    dataset_existing_plain = SimpleNamespace(
        id="dataset-1",
        index_struct="{}",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False},
    )
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings)
    assert result == "vector"
    assert store_cls.call_args.args[0] == "existing"
    # Existing dataset, UGC mode: shared index name derived from dimension/index/distance,
    # with the class_prefix used as the routing value.
    dataset_existing_ugc = SimpleNamespace(
        id="dataset-1",
        index_struct="{}",
        index_struct_dict={
            "vector_store": {"class_prefix": "ROUTING"},
            "using_ugc": True,
            "dimension": 1536,
            "index_type": "hnsw",
            "distance_type": "l2",
        },
    )
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2"
    assert store_cls.call_args.kwargs["routing_value"] == "ROUTING"
    # New dataset, UGC mode: dimension comes from embed_query (3 dims here).
    dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={})
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True)
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_new, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2"
    assert store_cls.call_args.kwargs["routing_value"] == "auto_collection"
    assert dataset_new.index_struct is not None
    # New dataset, plain mode: generated collection name, no routing.
    dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={})
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False)
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "auto_collection"
    assert store_cls.call_args.kwargs["routing_value"] is None

View File

@ -0,0 +1,252 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_mo_vector_modules():
mo_vector = types.ModuleType("mo_vector")
mo_vector.__path__ = []
mo_vector_client = types.ModuleType("mo_vector.client")
class MoVectorClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_full_text_index = MagicMock()
self.insert = MagicMock()
self.get = MagicMock(return_value=[])
self.delete = MagicMock()
self.query_by_metadata = MagicMock(return_value=[])
self.query = MagicMock(return_value=[])
self.full_text_query = MagicMock(return_value=[])
mo_vector_client.MoVectorClient = MoVectorClient
mo_vector.client = mo_vector_client
return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client}
@pytest.fixture
def matrixone_module(monkeypatch):
    """Reload the matrixone vector module with the fake mo_vector modules installed."""
    for name, module in _build_fake_mo_vector_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    # Reload so module-level mo_vector references bind to the fakes.
    import core.rag.datasource.vdb.matrixone.matrixone_vector as module

    return importlib.reload(module)
def _valid_config(module):
return module.MatrixoneConfig(
host="localhost",
port=6001,
user="dump",
password="111",
database="dify",
metric="l2",
)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config host is required"),
        ("port", 0, "config port is required"),
        ("user", "", "config user is required"),
        ("password", "", "config password is required"),
        ("database", "", "config database is required"),
    ],
)
def test_matrixone_config_validation(matrixone_module, field, value, message):
    """Each required config field raises a ValidationError with its own message when blanked."""
    values = _valid_config(matrixone_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        matrixone_module.MatrixoneConfig.model_validate(values)
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch):
    """On a redis cache miss the client is built, the full-text index created, and the cache flag set."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())
    vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)
    # Collection names are lowercased when used as the table name.
    assert client.kwargs["table_name"] == "collection_1"
    client.create_full_text_index.assert_called_once()
    matrixone_module.redis_client.set.assert_called_once()
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch):
    """A redis cache hit means the collection already exists: no index creation, no cache write."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())
    vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)
    client.create_full_text_index.assert_not_called()
    matrixone_module.redis_client.set.assert_not_called()
def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module):
    """Decorated methods lazily build the client (dimension=None, create_table=False) when it is absent."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = None
    fake_client = MagicMock()
    fake_client.get.return_value = [{"id": "seg-1"}]
    vector._get_client = MagicMock(return_value=fake_client)
    exists = vector.text_exists("seg-1")
    assert exists is True
    # Lazy initialization path: no dimension known, no table creation requested.
    vector._get_client.assert_called_once_with(None, False)
def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module):
    """JSON-string metadata is parsed, score = 1 - distance, and sub-threshold hits are dropped."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    vector.client.full_text_query.return_value = [
        SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.7),
    ]
    docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    # distance 0.1 -> score 0.9 (kept); distance 0.7 -> score 0.3 (below 0.5, dropped).
    assert len(docs) == 1
    assert docs[0].page_content == "doc-a"
    assert docs[0].metadata["doc_id"] == "1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    assert vector.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}}
def test_get_type_and_create_delegate_to_add_texts(matrixone_module):
    """create() sizes the client from the embedding dimension and delegates to add_texts."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    fake_client = MagicMock()
    vector._get_client = MagicMock(return_value=fake_client)
    vector.add_texts = MagicMock(return_value=["seg-1"])
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    result = vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == "matrixone"
    assert result == ["seg-1"]
    # Dimension 2 comes from the embedding length; create_table=True on initial creation.
    vector._get_client.assert_called_once_with(2, True)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch):
    """A failing create_full_text_index still returns the client but leaves the cache flag unset."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())
    failing_client = MagicMock()
    failing_client.create_full_text_index.side_effect = RuntimeError("boom")
    monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client))
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)
    assert client is failing_client
    # The "created" flag must not be cached when index creation failed.
    matrixone_module.redis_client.set.assert_not_called()
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch):
    """ids come from doc_id metadata or a generated uuid; metadata-less docs contribute no id."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    # Deterministic uuid so the generated id can be asserted exactly.
    monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}),
        Document(page_content="b", metadata={"document_id": "d-2"}),
        SimpleNamespace(page_content="c", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # For current prod code, only docs with metadata get ids, so only two ids
    assert ids == ["doc-a", "generated-uuid"]
    vector.client.insert.assert_called_once()
    insert_kwargs = vector.client.insert.call_args.kwargs
    # All lists passed to insert should be the same length
    texts = insert_kwargs["texts"]
    embeddings = insert_kwargs["embeddings"]
    metadatas = insert_kwargs["metadatas"]
    ids_insert = insert_kwargs["ids"]
    assert len(texts) == len(embeddings) == len(metadatas) == len(docs)
    # ids may be shorter than docs for current prod code, but should match number of docs with metadata
    assert ids_insert == ["doc-a", "generated-uuid"]
def test_delete_and_metadata_methods(matrixone_module):
    """Empty-id deletes are a no-op; id/metadata/whole-collection deletes each hit the client."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    vector.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")]
    vector.delete_by_ids([])
    vector.client.delete.assert_not_called()
    vector.delete_by_ids(["seg-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    vector.delete()
    assert ids == ["seg-1", "seg-2"]
    # Three delete calls: by ids, by metadata field, and the full-collection delete.
    assert vector.client.delete.call_count == 3
def test_search_by_vector_builds_documents(matrixone_module):
    """Vector query rows map 1:1 to Documents; the document-id filter is forwarded to the client."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    vector.client.query.return_value = [
        SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}),
    ]
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 2
    assert docs[0].page_content == "doc-a"
    assert docs[1].metadata["doc_id"] == "2"
    assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}}
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch):
    """Factory prefers the collection name stored in index_struct_dict, else generates one and persists it."""
    factory = matrixone_module.MatrixoneVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001)
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2")
    with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A dataset without an index struct gets one written back by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -1,18 +1,414 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
from core.rag.models.document import Document
def test_default_value():
def _build_fake_pymilvus_modules():
pymilvus = types.ModuleType("pymilvus")
pymilvus.__path__ = []
pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client")
pymilvus_orm = types.ModuleType("pymilvus.orm")
pymilvus_orm.__path__ = []
pymilvus_orm_types = types.ModuleType("pymilvus.orm.types")
class MilvusError(Exception):
pass
class MilvusClient:
def __init__(self, **kwargs):
self.init_kwargs = kwargs
self.has_collection = MagicMock(return_value=False)
self.describe_collection = MagicMock(
return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]}
)
self.get_server_version = MagicMock(return_value="2.5.0")
self.insert = MagicMock(return_value=[1])
self.query = MagicMock(return_value=[])
self.delete = MagicMock()
self.drop_collection = MagicMock()
self.search = MagicMock(return_value=[[]])
self.create_collection = MagicMock()
class IndexParams:
def __init__(self):
self.indexes = []
def add_index(self, **kwargs):
self.indexes.append(kwargs)
class DataType:
JSON = "JSON"
VARCHAR = "VARCHAR"
INT64 = "INT64"
SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR"
FLOAT_VECTOR = "FLOAT_VECTOR"
class FieldSchema:
def __init__(self, name, dtype, **kwargs):
self.name = name
self.dtype = dtype
self.kwargs = kwargs
class CollectionSchema:
def __init__(self, fields):
self.fields = fields
self.functions = []
def add_function(self, func):
self.functions.append(func)
class FunctionType:
BM25 = "BM25"
class Function:
def __init__(self, **kwargs):
self.kwargs = kwargs
def infer_dtype_bydata(_value):
return DataType.FLOAT_VECTOR
pymilvus.MilvusException = MilvusError
pymilvus.MilvusClient = MilvusClient
pymilvus.IndexParams = IndexParams
pymilvus.CollectionSchema = CollectionSchema
pymilvus.DataType = DataType
pymilvus.FieldSchema = FieldSchema
pymilvus.Function = Function
pymilvus.FunctionType = FunctionType
pymilvus_milvus_client.IndexParams = IndexParams
pymilvus_orm.types = pymilvus_orm_types
pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata
# Attach submodules for dotted imports
pymilvus.milvus_client = pymilvus_milvus_client
pymilvus.orm = pymilvus_orm
return {
"pymilvus": pymilvus,
"pymilvus.milvus_client": pymilvus_milvus_client,
"pymilvus.orm": pymilvus_orm,
"pymilvus.orm.types": pymilvus_orm_types,
}
@pytest.fixture
def milvus_module(monkeypatch):
    """Reload the milvus vector module with the fake pymilvus modules installed."""
    for name, module in _build_fake_pymilvus_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    # Reload so module-level pymilvus references bind to the fakes.
    import core.rag.datasource.vdb.milvus.milvus_vector as module

    return importlib.reload(module)
def _config(module, **overrides):
values = {
"uri": "http://localhost:19530",
"user": "root",
"password": "Milvus",
"database": "default",
"enable_hybrid_search": False,
"analyzer_params": None,
}
values.update(overrides)
return module.MilvusConfig.model_validate(values)
def test_config_validation_and_defaults(milvus_module):
    """Missing required fields raise with a field-specific message; defaults and token auth work.

    Cleanup: this block contained interleaved old/new lines from a merge — duplicate
    calls through the bare `MilvusConfig` import alongside the `milvus_module.` fixture
    variants. The stale bare-import duplicates are removed; only the fixture-based
    calls (which run against the fake pymilvus modules) remain.
    """
    valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"}
    # Dropping any one required key must produce the matching "MILVUS_<KEY> is required" error.
    for key in valid_config:
        config = valid_config.copy()
        del config[key]
        with pytest.raises(ValidationError) as e:
            milvus_module.MilvusConfig.model_validate(config)
        assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
    config = milvus_module.MilvusConfig.model_validate(valid_config)
    assert config.database == "default"
    # Token-based auth is accepted as an alternative to user/password.
    token_config = milvus_module.MilvusConfig.model_validate(
        {"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"}
    )
    assert token_config.token == "token-value"
def test_config_to_milvus_params(milvus_module):
    """analyzer_params and connection settings pass through to the milvus param dict."""
    config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}')
    params = config.to_milvus_params()
    assert params["uri"] == "http://localhost:19530"
    assert params["db_name"] == "default"
    assert params["analyzer_params"] == '{"tokenizer":"standard"}'
def test_init_client_supports_token_and_user_password(milvus_module):
    """_init_client authenticates with a token when configured, else with user/password."""
    # __new__ bypasses __init__ so _init_client can be exercised in isolation.
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    token_client = vector._init_client(
        milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"})
    )
    assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"}
    user_client = vector._init_client(_config(milvus_module))
    assert user_client.init_kwargs["uri"] == "http://localhost:19530"
    assert user_client.init_kwargs["user"] == "root"
    assert user_client.init_kwargs["password"] == "Milvus"
def test_init_loads_fields_when_collection_exists(milvus_module):
    """The constructor loads the existing collection's fields, excluding the reserved "id" field."""
    client = milvus_module.MilvusClient(uri="http://localhost:19530")
    client.has_collection.return_value = True
    client.describe_collection.return_value = {
        "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}]
    }
    with patch.object(milvus_module.MilvusVector, "_init_client", return_value=client):
        with patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False):
            vector = milvus_module.MilvusVector("collection_1", _config(milvus_module))
    assert "id" not in vector._fields
    assert "content" in vector._fields
def test_load_collection_fields_from_argument_and_remote(milvus_module):
    """_load_collection_fields accepts explicit names or fetches them, dropping 'id'."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._client = MagicMock()
    vector._collection_name = "collection_1"
    vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]}
    # Explicit field names: the reserved "id" field is filtered out.
    vector._load_collection_fields(["id", "metadata"])
    assert vector._fields == ["metadata"]
    # No argument: field names come from describe_collection, again minus "id".
    vector._load_collection_fields()
    assert vector._fields == ["content"]
def test_check_hybrid_search_support_branches(milvus_module):
    """Cover the observable branches of _check_hybrid_search_support.

    Disabled config -> False; "Zilliz Cloud" version -> True; "2.5.1" -> True;
    "2.4.9" -> False; version lookup failure -> False.
    """
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._client = MagicMock()
    # Hybrid search disabled in config short-circuits to False.
    vector._client_config = SimpleNamespace(enable_hybrid_search=False)
    assert vector._check_hybrid_search_support() is False
    vector._client_config = SimpleNamespace(enable_hybrid_search=True)
    vector._client.get_server_version.return_value = "Zilliz Cloud 2.4"
    assert vector._check_hybrid_search_support() is True
    vector._client.get_server_version.return_value = "2.5.1"
    assert vector._check_hybrid_search_support() is True
    vector._client.get_server_version.return_value = "2.4.9"
    assert vector._check_hybrid_search_support() is False
    # Any error while fetching the server version is treated as unsupported.
    vector._client.get_server_version.side_effect = RuntimeError("boom")
    assert vector._check_hybrid_search_support() is False
def test_get_type_and_create_delegate(milvus_module):
    """create() forwards embeddings/metadata to create_collection and add_texts."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    embeddings = [[0.1, 0.2]]
    documents = [SimpleNamespace(page_content="hello", metadata=None)]
    vector.create(documents, embeddings)
    assert vector.get_type() == "milvus"
    vector.create_collection.assert_called_once()
    first_arg, second_arg = vector.create_collection.call_args.args[:2]
    assert first_arg == embeddings
    # A document without metadata is normalized to an empty metadata dict.
    assert second_arg == [{}]
    vector.add_texts.assert_called_once_with(documents, embeddings)
def test_add_texts_batches_and_raises_milvus_exception(milvus_module):
    """add_texts splits 1001 docs into two insert batches and propagates MilvusException."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    # One id list per expected insert batch (1001 docs -> 2 batches).
    vector._client.insert.side_effect = [["id-1"], ["id-2"]]
    docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)]
    embeddings = [[0.1, 0.2] for _ in range(1001)]
    ids = vector.add_texts(docs, embeddings)
    assert ids == ["id-1", "id-2"]
    assert vector._client.insert.call_count == 2
    # Insert failures surface to the caller as MilvusException.
    vector._client.insert.side_effect = milvus_module.MilvusException("insert failed")
    with pytest.raises(milvus_module.MilvusException):
        vector.add_texts([Document(page_content="x", metadata={})], [[0.1]])
def test_get_ids_and_delete_methods(milvus_module):
    """Cover id lookup by metadata plus the three delete entry points."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.query.return_value = [{"id": 1}, {"id": 2}]
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2]
    # No matching rows yields None rather than an empty list.
    vector._client.query.return_value = []
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
    vector._client.has_collection.return_value = True
    # Stub the lookup so delete_by_metadata_field deletes exactly these pks.
    vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102])
    vector._client.delete.reset_mock()
    vector._client.query.return_value = [{"id": 11}, {"id": 12}]
    vector.delete_by_ids(["doc-a", "doc-b"])
    vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12])
    vector._client.has_collection.return_value = True
    vector.delete()
    vector._client.drop_collection.assert_called_once_with("collection_1", None)
def test_text_exists_and_field_exists(milvus_module):
    """text_exists requires the collection plus matching rows; field_exists checks _fields."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._fields = ["content", "metadata"]
    vector._client = MagicMock()
    # Missing collection short-circuits to False without inspecting rows.
    vector._client.has_collection.return_value = False
    assert vector.text_exists("doc-1") is False
    vector._client.has_collection.return_value = True
    # Existence follows directly from whether the query returned rows.
    for rows, expected in (([{"id": 1}], True), ([], False)):
        vector._client.query.return_value = rows
        assert vector.text_exists("doc-1") is expected
    assert vector.field_exists("content") is True
    assert vector.field_exists("unknown") is False
def test_process_search_results_and_search_methods(milvus_module):
    """_process_search_results applies the score threshold; search methods build filters."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._fields = ["content", "metadata", "sparse_vector"]
    # Only the 0.9-distance hit survives the 0.5 score threshold.
    processed = vector._process_search_results(
        [
            [
                {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9},
                {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2},
            ]
        ],
        [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY],
        score_threshold=0.5,
    )
    assert len(processed) == 1
    assert processed[0].metadata["score"] == 0.9
    vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]]
    vector._process_search_results = MagicMock(return_value=["doc"])
    docs = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1)
    assert docs == ["doc"]
    # The document id filter is translated into a Milvus metadata filter expression.
    assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]'
    # Full-text search returns [] without hybrid support or a sparse_vector field.
    vector._hybrid_search_enabled = False
    assert vector.search_by_full_text("query") == []
    vector._hybrid_search_enabled = True
    vector._fields = []
    assert vector.search_by_full_text("query") == []
    vector._fields = [milvus_module.Field.SPARSE_VECTOR]
    vector._process_search_results = MagicMock(return_value=["full-text-doc"])
    full_text_docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2)
    assert full_text_docs == ["full-text-doc"]
    assert "document_id" in vector._client.search.call_args.kwargs["filter"]
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch):
    """create_collection skips creation on a redis cache hit or an existing collection."""
    # No-op redis lock so the guarded section runs inline.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock())
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._consistency_level = "Session"
    vector._client_config = _config(milvus_module)
    vector._hybrid_search_enabled = False
    vector._client = MagicMock()
    # Cache hit: creation is skipped entirely.
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"})
    vector._client.create_collection.assert_not_called()
    # Cache miss but the collection already exists: only the cache flag is written.
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.has_collection.return_value = True
    vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"})
    milvus_module.redis_client.set.assert_called()
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch):
    """create_collection with hybrid search builds a sparse-vector schema and two indexes."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock())
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._consistency_level = "Session"
    vector._client = MagicMock()
    vector._client.has_collection.return_value = False
    vector._load_collection_fields = MagicMock()
    vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}')
    vector._hybrid_search_enabled = True
    vector.create_collection(
        embeddings=[[0.1, 0.2]],
        metadatas=[{"doc_id": "1"}],
        index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}},
    )
    call_kwargs = vector._client.create_collection.call_args.kwargs
    schema = call_kwargs["schema"]
    index_params_obj = call_kwargs["index_params"]
    field_names = [f.name for f in schema.fields]
    # Hybrid search adds a sparse vector field, one schema function and a second index.
    assert milvus_module.Field.SPARSE_VECTOR in field_names
    assert len(schema.functions) == 1
    assert len(index_params_obj.indexes) == 2
    assert call_kwargs["consistency_level"] == "Session"
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch):
    """Factory reuses a stored collection name or generates one from the dataset id."""
    factory = milvus_module.MilvusVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True)
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}')
    with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Existing index struct wins; otherwise the name is generated and persisted.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,230 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_clickhouse_connect_module():
clickhouse_connect = types.ModuleType("clickhouse_connect")
class QueryResult:
def __init__(self, rows=None, named_rows=None):
self.row_count = len(rows or [])
self.result_rows = rows or []
self._named_rows = named_rows or []
def named_results(self):
return self._named_rows
class Client:
def __init__(self):
self.command = MagicMock()
self.query = MagicMock(return_value=QueryResult())
client = Client()
def get_client(**_kwargs):
return client
clickhouse_connect.get_client = get_client
clickhouse_connect.QueryResult = QueryResult
clickhouse_connect._fake_client = client
return clickhouse_connect
@pytest.fixture
def myscale_module(monkeypatch):
    """Reload the MyScale vector module against a stubbed clickhouse_connect."""
    monkeypatch.setitem(sys.modules, "clickhouse_connect", _build_fake_clickhouse_connect_module())
    import core.rag.datasource.vdb.myscale.myscale_vector as module

    # Reload so module-level clickhouse_connect bindings pick up the stub.
    return importlib.reload(module)
def _config(module):
    """Return a minimal MyScaleConfig pointing at a local test server."""
    settings = {
        "host": "localhost",
        "port": 8123,
        "user": "default",
        "password": "",
        "database": "dify",
        "fts_params": "",
    }
    return module.MyScaleConfig(**settings)
def test_escape_str_replaces_backslash_and_quote(myscale_module):
    """escape_str maps backslashes and single quotes to spaces."""
    raw = r"text\with'special"
    assert myscale_module.MyScaleVector.escape_str(raw) == "text with special"
def test_search_raises_for_invalid_top_k(myscale_module):
    """_search rejects a non-positive top_k up front."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    expected_message = "top_k must be a positive integer"
    with pytest.raises(ValueError, match=expected_message):
        vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0)
def test_search_builds_where_clause_for_cosine_threshold(myscale_module):
    """A score threshold is turned into a distance cutoff in the WHERE clause."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    # Feed one named row back through the shared fake client's QueryResult class.
    vector._client.query.return_value = myscale_module.get_client().query.return_value.__class__(
        named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}]
    )
    docs = vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2)
    assert len(docs) == 1
    sql = vector._client.query.call_args.args[0]
    # A 0.2 score threshold maps to a distance cutoff of 0.8 here.
    assert "WHERE dist < 0.8" in sql
def test_delete_by_ids_short_circuits_on_empty_list(myscale_module):
    """delete_by_ids with no ids must not issue any SQL command."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    command_mock = vector._client.command
    command_mock.reset_mock()
    vector.delete_by_ids([])
    command_mock.assert_not_called()
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch):
    """Factory lower-cases collection names from either the index struct or the generator."""
    factory = myscale_module.MyScaleVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123)
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "")
    with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one persisted during init.
    assert dataset_without_index.index_struct is not None
def test_init_and_get_type_set_expected_defaults(myscale_module):
    """Construction enables experimental object type and defaults to ASC ordering."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    assert vector.get_type() == "myscale"
    assert vector._vec_order == myscale_module.SortOrder.ASC
    # The constructor issues this session setting through the client.
    vector._client.command.assert_called_with("SET allow_experimental_object_type=1")
def test_create_calls_create_collection_and_add_texts(myscale_module):
    """create() sizes the collection from the embedding dim and delegates insertion."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["seg-1"])
    embeddings = [[0.1, 0.2]]
    documents = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    assert vector.create(documents, embeddings) == ["seg-1"]
    # Collection dimension is taken from the embedding vector length.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once()
def test_create_collection_builds_expected_sql(myscale_module):
    """_create_collection emits CREATE TABLE SQL with a vector-length constraint and FTS index."""
    config = myscale_module.MyScaleConfig(
        host="localhost",
        port=8123,
        user="default",
        password="",
        database="dify",
        fts_params="tokenizer=unicode",
    )
    vector = myscale_module.MyScaleVector("collection_1", config)
    # Drop the constructor's session-setting command from the call history.
    vector._client.command.reset_mock()
    vector._create_collection(3)
    # Two commands are issued; the second one contains the CREATE TABLE statement.
    assert vector._client.command.call_count == 2
    sql = vector._client.command.call_args_list[1].args[0]
    assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in sql
    assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in sql
    assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
    """add_texts escapes content, inserts rows, and returns doc ids (generated if absent)."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    # Pin uuid4 so the generated id is deterministic.
    monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}),
        Document(page_content="text-2", metadata={"document_id": "d-2"}),
        SimpleNamespace(page_content="text-3", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # NOTE(review): only two ids come back for three docs — the metadata-less
    # document appears to be skipped; confirm against the add_texts implementation.
    assert ids == ["doc-a", "generated-uuid"]
    sql = vector._client.command.call_args.args[0]
    assert "INSERT INTO dify.collection_1" in sql
    # Quote and backslash characters are replaced by spaces in the inserted text.
    assert "te xt 1" in sql
def test_text_exists_and_metadata_operations(myscale_module):
    """text_exists and id lookup read query results; both delete paths issue SQL commands."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)])
    assert vector.text_exists("id-1") is True
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    # At least one command per delete (the constructor also issues one).
    assert vector._client.command.call_count >= 2
def test_search_delegation_methods(myscale_module):
    """search_by_vector and search_by_full_text both delegate to _search."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._search = MagicMock(return_value=["result"])
    assert vector.search_by_vector([0.1, 0.2], top_k=2) == ["result"]
    assert vector.search_by_full_text("hello", top_k=2) == ["result"]
    assert vector._search.call_count == 2
def test_search_with_document_filter_and_exception(myscale_module):
    """_search adds an IN filter for document ids and returns [] on query errors."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.query.return_value = SimpleNamespace(
        named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}]
    )
    docs = vector._search(
        "distance(vector, [0.1])",
        myscale_module.SortOrder.ASC,
        top_k=2,
        document_ids_filter=["doc-1", "doc-2"],
    )
    assert len(docs) == 1
    sql = vector._client.query.call_args.args[0]
    # The document id filter shows up as a SQL IN clause over metadata.
    assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql
    # Query failures are swallowed and reported as an empty result set.
    vector._client.query.side_effect = RuntimeError("boom")
    assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == []
def test_delete_drops_table(myscale_module):
    """delete() drops the backing table via a single DROP TABLE command."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    command_mock = vector._client.command
    command_mock.reset_mock()
    vector.delete()
    command_mock.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1")

View File

@ -0,0 +1,553 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from core.rag.models.document import Document
def _build_fake_pyobvector_module():
pyobvector = types.ModuleType("pyobvector")
class VECTOR:
def __init__(self, dim):
self.dim = dim
def l2_distance(*_args, **_kwargs):
return "l2"
def cosine_distance(*_args, **_kwargs):
return "cosine"
def inner_product(*_args, **_kwargs):
return "inner_product"
class ObVecClient:
def __init__(self, **_kwargs):
self.metadata_obj = SimpleNamespace(tables={})
self.engine = MagicMock()
self.check_table_exists = MagicMock(return_value=False)
self.perform_raw_text_sql = MagicMock()
self.prepare_index_params = MagicMock()
self.create_table_with_index_params = MagicMock()
self.refresh_metadata = MagicMock()
self.insert = MagicMock()
self.refresh_index = MagicMock()
self.get = MagicMock()
self.delete = MagicMock()
self.set_ob_hnsw_ef_search = MagicMock()
self.ann_search = MagicMock(return_value=[])
self.drop_table_if_exist = MagicMock()
pyobvector.VECTOR = VECTOR
pyobvector.ObVecClient = ObVecClient
pyobvector.l2_distance = l2_distance
pyobvector.cosine_distance = cosine_distance
pyobvector.inner_product = inner_product
return pyobvector
@pytest.fixture
def oceanbase_module(monkeypatch):
    """Reload the OceanBase vector module against a stubbed pyobvector."""
    fake_pyobvector = _build_fake_pyobvector_module()
    monkeypatch.setitem(sys.modules, "pyobvector", fake_pyobvector)
    import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module

    # Reload so module-level pyobvector bindings pick up the stub.
    return importlib.reload(module)
def _config(module):
    """Return an OceanBaseVectorConfig suitable for the fake-client tests."""
    settings = {
        "host": "127.0.0.1",
        "port": 2881,
        "user": "root",
        "password": "secret",
        "database": "test",
        "enable_hybrid_search": True,
        "batch_size": 10,
    }
    return module.OceanBaseVectorConfig(**settings)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OCEANBASE_VECTOR_HOST is required"),
        ("port", 0, "config OCEANBASE_VECTOR_PORT is required"),
        ("user", "", "config OCEANBASE_VECTOR_USER is required"),
        ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"),
    ],
)
def test_oceanbase_config_validation(oceanbase_module, field, value, message):
    """Each required config field raises a ValidationError when empty/zero."""
    # Start from a valid config dict and invalidate one field at a time.
    values = _config(oceanbase_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        oceanbase_module.OceanBaseVectorConfig.model_validate(values)
def test_init_rejects_invalid_collection_name(oceanbase_module):
    """Collection names with invalid characters are rejected at construction."""
    config = _config(oceanbase_module)
    with pytest.raises(ValueError, match="Invalid collection name"):
        oceanbase_module.OceanBaseVector("invalid-name", config)
def test_distance_to_score_for_supported_metrics(oceanbase_module):
    """_distance_to_score converts raw distances per metric type."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    cases = [
        ("l2", 3.0, 0.25),
        ("cosine", 0.2, 0.8),
        ("inner_product", -0.2, 0.2),
    ]
    for metric, distance, expected in cases:
        vector._config = SimpleNamespace(metric_type=metric)
        assert vector._distance_to_score(distance) == pytest.approx(expected)
def test_get_distance_func_raises_for_unknown_metric(oceanbase_module):
    """Unknown metric types cause _get_distance_func to raise ValueError."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    unsupported = SimpleNamespace(metric_type="manhattan")
    vector._config = unsupported
    with pytest.raises(ValueError, match="Unsupported metric_type"):
        vector._get_distance_func()
def test_process_search_results_handles_json_and_score_threshold(oceanbase_module):
    """_process_search_results parses metadata (JSON, invalid, dict) and filters by score."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    rows = [
        ("doc-1", '{"doc_id":"1"}', 0.9),  # JSON string metadata
        ("doc-2", "not-json", 0.8),  # unparseable metadata string
        ("doc-3", {"doc_id": "3"}, 0.3),  # below the 0.5 score threshold
    ]
    docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank")
    assert len(docs) == 2
    assert docs[0].metadata["doc_id"] == "1"
    # The score is stored in metadata under the caller-provided score_key.
    assert docs[0].metadata["rank"] == 0.9
    assert docs[1].metadata["rank"] == 0.8
def test_search_by_vector_validates_document_id_format(oceanbase_module):
    """search_by_vector rejects document ids that fail format validation."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()
    # "bad id" contains a space, which is not an acceptable document id.
    with pytest.raises(ValueError, match="Invalid document ID format"):
        vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"])
def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module):
    """Full-text search yields no results when hybrid search is disabled."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hybrid_search_enabled = False
    assert vector.search_by_full_text("query") == []
def test_check_hybrid_search_support_uses_version_comment(oceanbase_module):
    """Hybrid support is decided from the server version string."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client = MagicMock()
    cursor = MagicMock()
    cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",)
    vector._client.perform_raw_text_sql.return_value = cursor
    # 4.3.5.1 reports hybrid support; 4.3.4.0 does not.
    assert vector._check_hybrid_search_support() is True
    cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",)
    assert vector._check_hybrid_search_support() is False
def test_init_get_type_and_field_loading(oceanbase_module):
    """Construction against an existing table loads its column names."""
    config = _config(oceanbase_module)
    config.enable_hybrid_search = False
    table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")])
    fake_client = oceanbase_module.ObVecClient()
    fake_client.check_table_exists.return_value = True
    fake_client.metadata_obj.tables = {"collection_1": table}
    # Force the constructor to use our pre-seeded fake client.
    with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client):
        vector = oceanbase_module.OceanBaseVector("collection_1", config)
    assert vector.get_type() == "oceanbase"
    assert vector.field_exists("text") is True
def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module):
    """_load_collection_fields leaves fields empty if the table is missing or errors."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._fields = []
    vector._client = MagicMock()
    # Table not present in metadata: fields stay empty.
    vector._client.metadata_obj.tables = {}
    vector._load_collection_fields()
    assert vector._fields == []
    # Errors while reading columns are swallowed; fields stay empty.
    vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))}
    vector._load_collection_fields()
    assert vector._fields == []
def test_create_delegates_to_collection_and_insert(oceanbase_module):
    """create() records the vector dimension then delegates to the helpers."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    embeddings = [[0.1, 0.2]]
    documents = [Document(page_content="text", metadata={"doc_id": "1"})]
    vector.create(documents, embeddings)
    # Dimension is inferred from the first embedding vector.
    assert vector._vec_dim == 2
    vector._create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(documents, embeddings)
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
    """_create_collection skips work on a redis cache hit or an existing table."""
    # No-op redis lock so the critical section runs inline.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = False
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    # Cache hit: nothing is even checked against the database.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection()
    vector._client.check_table_exists.assert_not_called()
    # Cache miss but the table already exists: no delete/recreate happens.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.check_table_exists.return_value = True
    vector._create_collection()
    vector.delete.assert_not_called()
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
    """Full _create_collection run: memory check, table + index creation, cache set."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 3
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    # First SQL call yields the memory-limit variable row ("30"); the later
    # entries cover the subsequent raw SQL statements.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        None,
        None,
    ]
    index_params = MagicMock()
    vector._client.prepare_index_params.return_value = index_params
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    vector._create_collection()
    vector.delete.assert_called_once()
    vector._client.create_table_with_index_params.assert_called_once()
    index_params.add_index.assert_called_once()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])
    oceanbase_module.redis_client.set.assert_called_once()
def test_create_collection_error_paths(oceanbase_module, monkeypatch):
    """_create_collection raises on missing memory config, set failure, or bad parser."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    # No rows at all: the memory limit variable cannot be found.
    vector._client.perform_raw_text_sql.return_value = []
    with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"):
        vector._create_collection()
    # Limit is "0" and the attempt to raise it fails (e.g. missing privilege).
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "0"]],
        RuntimeError("no privilege"),
    ]
    with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"):
        vector._create_collection()
    # Valid memory limit but an unknown full-text parser name.
    vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]]
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid")
    with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"):
        vector._create_collection()
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
    """Fulltext index failures raise; metadata index failures are tolerated."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    # Hybrid enabled: a fulltext index error is fatal.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        RuntimeError("fulltext failed"),
    ]
    with pytest.raises(Exception, match="Failed to add fulltext index"):
        vector._create_collection()
    # Hybrid disabled: a SQLAlchemy error on the metadata index is swallowed.
    vector._hybrid_search_enabled = False
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        SQLAlchemyError("metadata index failed"),
    ]
    vector._create_collection()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])
def test_check_hybrid_search_support_false_and_exception(oceanbase_module):
    """Hybrid support is False when disabled or when the version query fails."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._client = MagicMock()
    vector._config = SimpleNamespace(enable_hybrid_search=False)
    assert vector._check_hybrid_search_support() is False
    # With the feature enabled, a failing version probe still reports False.
    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom")
    assert vector._check_hybrid_search_support() is False
def test_add_texts_batches_refresh_and_exceptions(oceanbase_module):
    """add_texts inserts in config-sized batches, refreshes the HNSW index past the
    threshold, wraps insert failures, and tolerates refresh failures."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2)
    vector._client = MagicMock()
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
        Document(page_content="c", metadata={"doc_id": "id-3"}),
    ]
    # 3 docs with batch_size=2 -> two insert calls; threshold reached -> one refresh.
    vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    assert vector._client.insert.call_count == 2
    vector._client.refresh_index.assert_called_once()
    vector._client.insert.reset_mock()
    vector._client.refresh_index.reset_mock()
    # Insert errors are surfaced as a wrapped batch-level exception.
    vector._client.insert.side_effect = RuntimeError("insert failed")
    with pytest.raises(Exception, match="Failed to insert batch"):
        vector.add_texts([docs[0]], [[0.1]])
    vector._client.insert.side_effect = None
    vector._client.insert.return_value = None
    # A failing refresh does not abort add_texts (call below must not raise).
    vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed")
    vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1)
    vector._get_uuids.return_value = ["id-1"]
    vector.add_texts([docs[0]], [[0.1]])
def test_text_exists_and_delete_by_ids(oceanbase_module):
    """text_exists reflects the query rowcount and wraps client errors; delete_by_ids
    skips empty input and wraps client failures."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.get.return_value = SimpleNamespace(rowcount=1)
    assert vector.text_exists("id-1") is True
    vector._client.get.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to check text existence"):
        vector.text_exists("id-1")
    # Empty id list: no delete issued at all.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector._client.delete.side_effect = None
    vector.delete_by_ids(["id-1"])
    vector._client.delete.assert_called_once()
    # Client errors during delete are wrapped.
    vector._client.delete.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete documents"):
        vector.delete_by_ids(["id-1"])
def test_get_ids_and_delete_by_metadata_field(oceanbase_module):
    """get_ids_by_metadata_field returns matching ids and fails on an unsafe key;
    delete_by_metadata_field deletes only when ids were found."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    execute_result = [("id-1",), ("id-2",)]
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value = execute_result
    vector._client.engine.connect.return_value = conn
    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert ids == ["id-1", "id-2"]
    # A metadata key with unexpected characters surfaces as a wrapped query failure.
    with pytest.raises(Exception, match="Failed to query documents by metadata field"):
        vector.get_ids_by_metadata_field("bad key!", "doc-1")
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
    # No matching ids -> no delete call.
    vector.get_ids_by_metadata_field = MagicMock(return_value=[])
    vector.delete_by_ids.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_not_called()
def test_search_by_full_text_paths(oceanbase_module):
    """Full-text search returns [] when the fulltext field is missing, builds scored
    documents from result rows, and wraps query failures."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hybrid_search_enabled = True
    vector.field_exists = MagicMock(return_value=False)
    # Missing fulltext field short-circuits to an empty result.
    assert vector.search_by_full_text("query") == []
    vector.field_exists.return_value = True
    vector._client = MagicMock()
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = tx
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    # Rows are (text, metadata JSON, score).
    conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)]
    vector._client.engine.connect.return_value = conn
    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.9
    # top_k=0 surfaces as a wrapped full-text failure.
    with pytest.raises(Exception, match="Full-text search failed"):
        vector.search_by_full_text("query", top_k=0)
def test_search_by_vector_paths(oceanbase_module):
    """Vector search forwards an explicit ef_search, validates score_threshold, and
    wraps client errors."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()
    vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)]
    vector._process_search_results = MagicMock(return_value=["doc"])
    docs = vector.search_by_vector(
        [0.1, 0.2],
        ef_search=10,
        top_k=3,
        score_threshold=0.1,
        document_ids_filter=["good_id"],
    )
    assert docs == ["doc"]
    # ef_search=10 differs from the cached -1, so it is applied via the client.
    vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10)
    # Non-numeric score_threshold is rejected.
    with pytest.raises(ValueError, match="Invalid score_threshold parameter"):
        vector.search_by_vector([0.1], score_threshold="x")
    # Client errors are wrapped.
    vector._client.ann_search.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Vector search failed"):
        vector.search_by_vector([0.1], score_threshold=0.1)
def test_get_distance_func_and_distance_to_score_errors(oceanbase_module):
    """Cosine metric maps to cosine_distance; an unknown metric raises ValueError."""
    vec_cls = oceanbase_module.OceanBaseVector
    instance = vec_cls.__new__(vec_cls)
    instance._config = SimpleNamespace(metric_type="cosine")
    assert instance._get_distance_func() is oceanbase_module.cosine_distance
    # Any unrecognised metric type is rejected when converting a distance to a score.
    instance._config = SimpleNamespace(metric_type="unknown")
    with pytest.raises(ValueError, match="Unsupported metric_type"):
        instance._distance_to_score(0.1)
def test_delete_success_and_exception(oceanbase_module):
    """delete() drops the backing table; client failures are wrapped and re-raised."""
    vec_cls = oceanbase_module.OceanBaseVector
    instance = vec_cls.__new__(vec_cls)
    instance._collection_name = "collection_1"
    instance._client = MagicMock()
    # Happy path: the collection's table is dropped exactly once.
    instance.delete()
    instance._client.drop_table_if_exist.assert_called_once_with("collection_1")
    # Client error surfaces as a wrapped exception.
    instance._client.drop_table_if_exist.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete collection"):
        instance.delete()
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
    """The factory reuses the collection name stored in index_struct_dict, otherwise
    generates one; either way the name passed on is lower-cased, and a newly named
    dataset gets its index_struct persisted."""
    factory = oceanbase_module.OceanBaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Minimal connection/index configuration so the factory can build its config object.
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000)
    with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Both names are lower-cased before being handed to OceanBaseVector.
    assert vector_cls.call_args_list[0].args[0] == "existing_collection"
    assert vector_cls.call_args_list[1].args[0] == "auto_collection"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,400 @@
import importlib
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_psycopg2_modules():
psycopg2 = types.ModuleType("psycopg2")
psycopg2.__path__ = []
psycopg2_extras = types.ModuleType("psycopg2.extras")
psycopg2_pool = types.ModuleType("psycopg2.pool")
class SimpleConnectionPool:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.getconn = MagicMock()
self.putconn = MagicMock()
psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool
psycopg2_extras.execute_values = MagicMock()
psycopg2.pool = psycopg2_pool
psycopg2.extras = psycopg2_extras
return {
"psycopg2": psycopg2,
"psycopg2.pool": psycopg2_pool,
"psycopg2.extras": psycopg2_extras,
}
@pytest.fixture
def opengauss_module(monkeypatch):
    """Import the opengauss vector module with stubbed psycopg2 dependencies installed."""
    fakes = _build_fake_psycopg2_modules()
    for mod_name in fakes:
        monkeypatch.setitem(sys.modules, mod_name, fakes[mod_name])
    import core.rag.datasource.vdb.opengauss.opengauss as opengauss
    # Reload so the module binds against the freshly injected fakes.
    return importlib.reload(opengauss)
def _config(module, *, enable_pq=False):
    """Build a valid OpenGaussConfig for tests; PQ indexing is opt-in via *enable_pq*."""
    settings = {
        "host": "localhost",
        "port": 6600,
        "user": "postgres",
        "password": "password",
        "database": "dify",
        "min_connection": 1,
        "max_connection": 5,
        "enable_pq": enable_pq,
    }
    return module.OpenGaussConfig(**settings)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OPENGAUSS_HOST is required"),
        ("port", 0, "config OPENGAUSS_PORT is required"),
        ("user", "", "config OPENGAUSS_USER is required"),
        ("password", "", "config OPENGAUSS_PASSWORD is required"),
        ("database", "", "config OPENGAUSS_DATABASE is required"),
        ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"),
        ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"),
    ],
)
def test_opengauss_config_validation(opengauss_module, field, value, message):
    """Each required OpenGauss setting rejects an empty/zero value with its own message."""
    values = _config(opengauss_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        opengauss_module.OpenGaussConfig.model_validate(values)
def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module):
    """A pool minimum above the maximum must fail config validation."""
    payload = _config(opengauss_module).model_dump()
    payload.update(min_connection=6, max_connection=5)
    expected = "OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"
    with pytest.raises(ValidationError, match=expected):
        opengauss_module.OpenGaussConfig.model_validate(payload)
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
    """The constructor derives the table name, reports its type, and keeps the pool it built."""
    fake_pool = MagicMock()
    pool_factory = MagicMock(return_value=fake_pool)
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", pool_factory)
    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    assert store.table_name == "embedding_collection_1"
    assert store.get_type() == "opengauss"
    assert store.pool is fake_pool
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
    """With PQ enabled, index creation emits PQ-specific SQL and records completion in redis."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    vector._create_index(1536)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    # PQ path must enable product quantization and tune the early-stop threshold.
    assert any("enable_pq=on" in sql for sql in executed_sql)
    assert any("SET hnsw_earlystop_threshold = 320" in sql for sql in executed_sql)
    opengauss_module.redis_client.set.assert_called_once()
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
    """Dimensions beyond the indexable size skip index SQL but still mark the redis cache."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # 3072 dims: no index SQL is executed, yet completion is cached.
    vector._create_index(3072)
    cursor.execute.assert_not_called()
    opengauss_module.redis_client.set.assert_called_once()
def test_search_by_vector_validates_top_k(opengauss_module):
    """A non-positive top_k is rejected before any database work happens."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        store.search_by_vector([0.1, 0.2], top_k=0)
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
    """An empty id list must not open a cursor at all."""
    monkeypatch.setattr(
        opengauss_module.psycopg2.pool,
        "SimpleConnectionPool",
        MagicMock(return_value=MagicMock()),
    )
    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    store._get_cursor = MagicMock()
    store.delete_by_ids([])
    store._get_cursor.assert_not_called()
def test_get_cursor_closes_commits_and_returns_connection(opengauss_module):
    """_get_cursor yields a cursor, then closes it, commits, and returns the
    connection to the pool on exit."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    pool = MagicMock()
    conn = MagicMock()
    cur = MagicMock()
    pool.getconn.return_value = conn
    conn.cursor.return_value = cur
    vector.pool = pool
    with vector._get_cursor() as got_cur:
        assert got_cur is cur
    # Cleanup happens after the with-block exits.
    cur.close.assert_called_once()
    conn.commit.assert_called_once()
    pool.putconn.assert_called_once_with(conn)
def test_create_calls_collection_insert_and_index(opengauss_module):
    """create() sets up the collection, inserts the texts, then builds the index."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    for hook in ("_create_collection", "add_texts", "_create_index"):
        setattr(store, hook, MagicMock())
    documents = [Document(page_content="text", metadata={"doc_id": "seg-1"})]
    embeddings = [[0.1, 0.2]]
    store.create(documents, embeddings)
    # The embedding dimension (2) drives both collection creation and indexing.
    store._create_collection.assert_called_once_with(2)
    store.add_texts.assert_called_once_with(documents, embeddings)
    store._create_index.assert_called_once_with(2)
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
    """A redis cache hit means the index already exists: no SQL and no cache write."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    # Non-None get() simulates a prior successful index creation.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    vector._get_cursor = MagicMock()
    vector._create_index(1536)
    vector._get_cursor.assert_not_called()
    opengauss_module.redis_client.set.assert_not_called()
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
    """Without PQ, index creation emits the standard cosine index statement."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    vector._create_index(1536)
    sql = [call.args[0] for call in cursor.execute.call_args_list]
    # The index name encodes the cosine metric and the table name.
    assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql)
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
    """add_texts bulk-inserts via psycopg2 execute_values; returned ids come from
    document metadata (a document without metadata contributes no id)."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cursor = MagicMock()
    opengauss_module.psycopg2.extras.execute_values.reset_mock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    docs = [
        Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}),
        SimpleNamespace(page_content="text-2", metadata=None),
    ]
    # Deterministic uuid for the metadata-less document.
    monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid")
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    # Only the Document carrying metadata supplies a doc_id.
    assert ids == ["seg-1"]
    opengauss_module.psycopg2.extras.execute_values.assert_called_once()
def test_text_exists_and_get_by_ids(opengauss_module):
    """text_exists is True when a row is fetched; get_by_ids maps (meta, text) rows
    to Document objects."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("seg-1",)
    # Iterating the cursor yields (metadata, page_content) tuples.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    assert vector.text_exists("seg-1") is True
    docs = vector.get_by_ids(["seg-1", "seg-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
def test_delete_and_metadata_field_queries(opengauss_module):
    """Deletion helpers issue the expected DELETE / metadata-filter / DROP SQL shapes."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    vector.delete_by_ids(["seg-1", "seg-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    sql = [call.args[0] for call in cursor.execute.call_args_list]
    # One statement per operation, all parameterized against the table.
    assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in query for query in sql)
    assert any("meta->>%s = %s" in query for query in sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in query for query in sql)
def test_search_by_vector_and_full_text(opengauss_module):
    """Vector search converts distance to score and applies the threshold; full-text
    search maps rows straight to Documents."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows are (metadata, page_content, distance).
    cursor.__iter__.return_value = iter(
        [
            ({"doc_id": "1"}, "text-1", 0.1),
            ({"doc_id": "2"}, "text-2", 0.6),
        ]
    )
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    # Distance 0.1 -> score 0.9 passes the 0.5 threshold; distance 0.6 -> 0.4 is dropped.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)])
    full_docs = vector.search_by_full_text("hello world", top_k=2)
    assert len(full_docs) == 1
    assert full_docs[0].page_content == "full-text"
def test_search_by_full_text_validates_top_k(opengauss_module):
    """Full-text search rejects a non-positive top_k up front."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        store.search_by_full_text("query", top_k=0)
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
    """Collection creation is skipped on a redis cache hit and executed (then cached)
    on a miss."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # Cache hit: nothing executed.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(1536)
    cursor.execute.assert_not_called()
    # Cache miss: the table is created and the cache key written.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(1536)
    cursor.execute.assert_called_once()
    opengauss_module.redis_client.set.assert_called_once()
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
    """The factory reuses the collection name from index_struct_dict, otherwise
    generates one and persists the new index struct on the dataset."""
    factory = opengauss_module.OpenGaussFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Minimal connection configuration so the factory can build its config object.
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_HOST", "localhost")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PORT", 6600)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_USER", "postgres")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PASSWORD", "password")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_DATABASE", "dify")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MIN_CONNECTION", 1)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MAX_CONNECTION", 5)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_ENABLE_PQ", False)
    with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Unlike OceanBase, the names are passed through without lower-casing.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,360 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_opensearch_modules():
opensearchpy = types.ModuleType("opensearchpy")
opensearchpy_helpers = types.ModuleType("opensearchpy.helpers")
class BulkIndexError(Exception):
def __init__(self, errors):
super().__init__("bulk error")
self.errors = errors
class Urllib3AWSV4SignerAuth:
def __init__(self, credentials, region, service):
self.credentials = credentials
self.region = region
self.service = service
class Urllib3HttpConnection:
pass
class _IndicesClient:
def __init__(self):
self.exists = MagicMock(return_value=False)
self.create = MagicMock()
self.delete = MagicMock()
class OpenSearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.indices = _IndicesClient()
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.get = MagicMock()
helpers = SimpleNamespace(bulk=MagicMock())
opensearchpy.OpenSearch = OpenSearch
opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth
opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection
opensearchpy.helpers = helpers
opensearchpy_helpers.BulkIndexError = BulkIndexError
return {
"opensearchpy": opensearchpy,
"opensearchpy.helpers": opensearchpy_helpers,
}
@pytest.fixture
def opensearch_module(monkeypatch):
    """Import the opensearch vector module with stubbed opensearchpy dependencies installed."""
    fakes = _build_fake_opensearch_modules()
    for mod_name in fakes:
        monkeypatch.setitem(sys.modules, mod_name, fakes[mod_name])
    import core.rag.datasource.vdb.opensearch.opensearch_vector as opensearch_vector
    # Reload so the module binds against the freshly injected fakes.
    return importlib.reload(opensearch_vector)
def _config(module, **overrides):
    """Build a valid OpenSearchConfig; keyword *overrides* replace the baseline values."""
    baseline = dict(
        host="localhost",
        port=9200,
        secure=True,
        verify_certs=True,
        auth_method="basic",
        user="admin",
        password="secret",
    )
    return module.OpenSearchConfig.model_validate({**baseline, **overrides})
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OPENSEARCH_HOST is required"),
        ("port", 0, "config OPENSEARCH_PORT is required"),
    ],
)
def test_config_validation_required_fields(opensearch_module, field, value, message):
    """Host and port reject an empty/zero value, each with its own message."""
    values = _config(opensearch_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        opensearch_module.OpenSearchConfig.model_validate(values)
def test_config_validation_for_aws_auth_and_https_fields(opensearch_module):
    """AWS-managed IAM auth requires a region; verify_certs=True is invalid without TLS."""
    values = {
        "host": "localhost",
        "port": 9200,
        "secure": True,
        "verify_certs": True,
        "auth_method": "aws_managed_iam",
        "user": "admin",
        "password": "secret",
    }
    with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"):
        opensearch_module.OpenSearchConfig.model_validate(values)
    values = _config(opensearch_module).model_dump()
    # NOTE(review): these env-style keys differ from the dump's "secure"/"verify_certs"
    # field names -- confirm the model actually reads them (otherwise this asserts
    # against the baseline values rather than the intended overrides).
    values["OPENSEARCH_SECURE"] = False
    values["OPENSEARCH_VERIFY_CERTS"] = True
    with pytest.raises(ValidationError, match="verify_certs=True requires secure"):
        opensearch_module.OpenSearchConfig.model_validate(values)
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch):
    """The IAM signer is built from boto3 session credentials plus region/service config."""
    class _Session:
        def get_credentials(self):
            return "creds"
    # Inject a stub boto3 so no real AWS SDK is needed.
    boto3 = types.ModuleType("boto3")
    boto3.Session = _Session
    monkeypatch.setitem(sys.modules, "boto3", boto3)
    config = _config(
        opensearch_module,
        auth_method="aws_managed_iam",
        aws_region="us-east-1",
        aws_service="es",
    )
    auth = config.create_aws_managed_iam_auth()
    assert auth.credentials == "creds"
    assert auth.region == "us-east-1"
    assert auth.service == "es"
def test_to_opensearch_params_supports_basic_and_aws(opensearch_module):
    """Basic auth yields a (user, password) tuple; AWS IAM uses the signer object."""
    basic_params = _config(opensearch_module).to_opensearch_params()
    assert basic_params["http_auth"] == ("admin", "secret")
    aws_config = _config(
        opensearch_module,
        auth_method="aws_managed_iam",
        aws_region="us-west-2",
        aws_service="es",
    )
    # Stub the signer factory so no boto3 session is created.
    with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"):
        aws_params = aws_config.to_opensearch_params()
    assert aws_params["http_auth"] == "iam-auth"
def test_init_and_create_delegate_calls(opensearch_module):
    """create() delegates to create_collection (embeddings + metadatas) and add_texts."""
    vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module))
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == "opensearch"
    vector.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}])
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch):
    """Bulk actions carry an explicit _id for regular clusters but omit it for AOSS."""
    vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es"))
    docs = [
        Document(page_content="a", metadata={"doc_id": "1"}),
        Document(page_content="b", metadata={"doc_id": "2"}),
    ]
    # Deterministic ids for the bulk actions.
    monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id"))
    opensearch_module.helpers.bulk.reset_mock()
    vector.add_texts(docs, [[0.1], [0.2]])
    actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"]
    assert len(actions) == 2
    assert all("_id" in action for action in actions)
    # AOSS (serverless) does not accept client-supplied _id values.
    vector._client_config.aws_service = "aoss"
    opensearch_module.helpers.bulk.reset_mock()
    vector.add_texts(docs, [[0.3], [0.4]])
    aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"]
    assert all("_id" not in action for action in aoss_actions)
def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module):
    """get_ids_by_metadata_field returns hit ids (None when there are no hits) and
    feeds delete_by_metadata_field."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}}
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
    # Empty hit list maps to None, not [].
    vector._client.search.return_value = {"hits": {"hits": []}}
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module):
    """delete_by_ids skips a missing index, bulk-deletes only resolved ids, and
    tolerates a BulkIndexError consisting solely of 404s."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    opensearch_module.helpers.bulk.reset_mock()
    # Index absent: nothing to delete.
    vector._client.indices.exists.return_value = False
    vector.delete_by_ids(["doc-1"])
    opensearch_module.helpers.bulk.assert_not_called()
    # One id resolves to an ES id, the other does not; a single bulk call is issued.
    vector._client.indices.exists.return_value = True
    vector.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None])
    vector.delete_by_ids(["doc-1", "doc-2"])
    opensearch_module.helpers.bulk.assert_called_once()
    opensearch_module.helpers.bulk.reset_mock()
    # A bulk error whose entries are all status-404 deletes must not propagate.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["es-404"])
    opensearch_module.helpers.bulk.side_effect = opensearch_module.BulkIndexError(
        [{"delete": {"status": 404, "_id": "es-404"}}]
    )
    vector.delete_by_ids(["doc-404"])
    assert opensearch_module.helpers.bulk.call_count == 1
    opensearch_module.helpers.bulk.side_effect = None
def test_delete_and_text_exists(opensearch_module):
    """delete() removes the index (ignoring absence); text_exists maps client errors to False."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True)
    vector._client.get.return_value = {"_id": "id-1"}
    assert vector.text_exists("id-1") is True
    # A failed lookup is treated as "does not exist", not an error.
    vector._client.get.side_effect = RuntimeError("not found")
    assert vector.text_exists("id-1") is False
def test_search_by_vector_validates_and_builds_documents(opensearch_module):
    """search_by_vector validates the query vector, applies score_threshold, and
    switches to a script_score query when a document-id filter is supplied."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    # Non-list and non-float elements are both rejected.
    with pytest.raises(ValueError, match="query_vector should be a list"):
        vector.search_by_vector("not-a-list")
    with pytest.raises(ValueError, match="should be floats"):
        vector.search_by_vector([0.1, 1])
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        opensearch_module.Field.CONTENT_KEY: "doc-1",
                        opensearch_module.Field.METADATA_KEY: None,
                    },
                    "_score": 0.9,
                },
                {
                    "_source": {
                        opensearch_module.Field.CONTENT_KEY: "doc-2",
                        opensearch_module.Field.METADATA_KEY: {"doc_id": "2"},
                    },
                    "_score": 0.1,
                },
            ]
        }
    }
    # Only the 0.9-scored hit passes the 0.5 threshold.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].page_content == "doc-1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # With a document-id filter, the request body uses a script_score query.
    vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"])
    query = vector._client.search.call_args.kwargs["body"]
    assert "script_score" in query["query"]
def test_search_by_vector_reraises_client_error(opensearch_module):
    """Client-level search failures propagate unchanged to the caller."""
    store = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    store._client.search.side_effect = RuntimeError("boom")
    with pytest.raises(RuntimeError, match="boom"):
        store.search_by_vector([0.1, 0.2])
def test_search_by_full_text_and_filters(opensearch_module):
    """Full-text search maps hits to Documents and adds a terms filter for
    document_ids_filter."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        opensearch_module.Field.METADATA_KEY: {"doc_id": "1"},
                        opensearch_module.Field.VECTOR: [0.1],
                        opensearch_module.Field.CONTENT_KEY: "matched text",
                    }
                },
            ]
        }
    }
    docs = vector.search_by_full_text("hello", document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "matched text"
    # The id filter is expressed as a bool-filter terms clause.
    query = vector._client.search.call_args.kwargs["body"]
    assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}]
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch):
    """Index creation is skipped on a redis cache hit and executed on a miss."""
    # Stub the redis lock so the guarded section runs without a real redis server.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock())
    vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module))
    # Cache hit: no index creation call may be issued.
    monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1))
    vector._client.indices.create.reset_mock()
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()
    # Cache miss with a missing index: the index is created with the embedding's dimension.
    monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_called_once()
    index_body = vector._client.indices.create.call_args.kwargs["body"]
    assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2
    # The cache flag must be written back after successful creation.
    opensearch_module.redis_client.set.assert_called()
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch):
    """The factory lowercases both an existing class_prefix and a freshly generated name."""
    factory = opensearch_module.OpenSearchVectorFactory()
    # One dataset already carries an index struct, the other must have a name generated.
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide a complete basic-auth configuration so the factory can build client settings.
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_HOST", "localhost")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PORT", 9200)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_SECURE", True)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_VERIFY_CERTS", True)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AUTH_METHOD", "basic")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_USER", "admin")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PASSWORD", "secret")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_REGION", None)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_SERVICE", None)
    with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Both names are lowercased before being passed to the vector class.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one persisted on it.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,375 @@
import array
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import numpy
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_oracle_modules():
jieba = types.ModuleType("jieba")
jieba_posseg = types.ModuleType("jieba.posseg")
jieba_posseg.cut = MagicMock(return_value=[])
jieba.posseg = jieba_posseg
oracledb = types.ModuleType("oracledb")
oracledb_connection = types.ModuleType("oracledb.connection")
class Connection:
pass
oracledb_connection.Connection = Connection
oracledb.defaults = SimpleNamespace(fetch_lobs=True)
oracledb.DB_TYPE_VECTOR = object()
oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock()))
oracledb.connect = MagicMock()
return {
"jieba": jieba,
"jieba.posseg": jieba_posseg,
"oracledb": oracledb,
"oracledb.connection": oracledb_connection,
}
def _connection_with_cursor(cursor):
cursor_ctx = MagicMock()
cursor_ctx.__enter__.return_value = cursor
cursor_ctx.__exit__.return_value = None
connection = MagicMock()
connection.__enter__.return_value = connection
connection.__exit__.return_value = None
connection.cursor.return_value = cursor_ctx
return connection
@pytest.fixture
def oracle_module(monkeypatch):
    """Import the oraclevector module with jieba/oracledb replaced by in-memory stubs."""
    # Install the fakes into sys.modules BEFORE importing so module-level imports resolve.
    for name, module in _build_fake_oracle_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.oracle.oraclevector as module
    # Reload so a previously imported copy (bound to real deps) does not leak in.
    return importlib.reload(module)
def _config(module, **overrides):
values = {
"user": "system",
"password": "oracle",
"dsn": "oracle:1521/freepdb1",
"is_autonomous": False,
}
values.update(overrides)
return module.OracleVectorConfig.model_validate(values)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("user", "", "config ORACLE_USER is required"),
        ("password", "", "config ORACLE_PASSWORD is required"),
        ("dsn", "", "config ORACLE_DSN is required"),
    ],
)
def test_oracle_config_validation_required_fields(oracle_module, field, value, message):
    """Each mandatory connection field must fail validation with its own message when blank."""
    # Start from a valid config and blank out exactly one field per parametrized case.
    values = _config(oracle_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        oracle_module.OracleVectorConfig.model_validate(values)
def test_oracle_config_validation_autonomous_requirements(oracle_module):
    """Enabling autonomous mode without a wallet config_dir must fail validation."""
    payload = {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True}
    with pytest.raises(ValidationError, match="config_dir is required"):
        oracle_module.OracleVectorConfig.model_validate(payload)
def test_init_and_get_type(oracle_module, monkeypatch):
    """Construction should create a connection pool and derive the embedding table name."""
    pool = MagicMock()
    monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool))
    vector = oracle_module.OracleVector("collection_1", _config(oracle_module))
    assert vector.get_type() == "oracle"
    # Table name is the collection name with an "embedding_" prefix.
    assert vector.table_name == "embedding_collection_1"
    assert vector.pool is pool
def test_numpy_converters_and_type_handlers(oracle_module):
    """Numpy<->array conversions and oracledb type handlers must use matching typecodes."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    # Inbound: numpy dtype maps onto the stdlib array typecode (float64->d, float32->f, int8->b).
    in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64))
    in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32))
    in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8))
    assert in_float64.typecode == "d"
    assert in_float32.typecode == "f"
    assert in_int8.typecode == "b"
    # Input handler registers a VECTOR bind variable with the inbound converter.
    cursor = MagicMock()
    vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2)
    cursor.var.assert_called_with(
        oracle_module.oracledb.DB_TYPE_VECTOR,
        arraysize=2,
        inconverter=vector.numpy_converter_in,
    )
    # Output handler mirrors it for fetched VECTOR columns, honoring cursor.arraysize.
    metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR)
    cursor.arraysize = 3
    vector.output_type_handler(cursor, metadata)
    cursor.var.assert_called_with(
        metadata.type_code,
        arraysize=3,
        outconverter=vector.numpy_converter_out,
    )
    # Outbound: array typecode maps back to the corresponding numpy dtype.
    out_int8 = vector.numpy_converter_out(array.array("b", [1]))
    assert out_int8.dtype == numpy.int8
    out_float32 = vector.numpy_converter_out(array.array("f", [1.0]))
    assert out_float32.dtype == numpy.float32
    out_float64 = vector.numpy_converter_out(array.array("d", [1.0]))
    assert out_float64.dtype == numpy.float64
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch):
    """_get_connection passes plain credentials normally, plus wallet kwargs in autonomous mode."""
    connect = MagicMock(return_value="connection")
    monkeypatch.setattr(oracle_module.oracledb, "connect", connect)
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    # Standard path: only user/password/dsn are forwarded.
    vector.config = _config(oracle_module)
    assert vector._get_connection() == "connection"
    connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1")
    # Autonomous path: wallet settings must be forwarded as well.
    vector.config = _config(
        oracle_module,
        is_autonomous=True,
        config_dir="/wallet",
        wallet_location="/wallet",
        wallet_password="pw",
    )
    vector._get_connection()
    assert connect.call_args.kwargs["config_dir"] == "/wallet"
    assert connect.call_args.kwargs["wallet_location"] == "/wallet"
def test_create_delegates_collection_and_insert(oracle_module):
    """create() must size the collection from the embedding width and forward docs to add_texts."""
    store = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    store._create_collection = MagicMock()
    store.add_texts = MagicMock(return_value=["seg-1"])
    documents = [Document(page_content="doc", metadata={"doc_id": "seg-1"})]
    embeddings = [[0.1, 0.2]]
    assert store.create(documents, embeddings) == ["seg-1"]
    store._create_collection.assert_called_once_with(2)
    store.add_texts.assert_called_once_with(documents, embeddings)
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch):
    """add_texts inserts row-by-row, generates ids when missing, and tolerates a failed insert."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.input_type_handler = MagicMock()
    vector.output_type_handler = MagicMock()
    cursor = MagicMock()
    # Second execute raises to exercise the per-row failure handling path.
    cursor.execute.side_effect = [None, RuntimeError("insert failed")]
    connection = _connection_with_cursor(cursor)
    vector._get_connection = MagicMock(return_value=connection)
    monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        SimpleNamespace(page_content="c", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # A doc without "doc_id" gets a generated uuid; the metadata-less doc yields no insert.
    assert ids == ["doc-a", "generated-uuid"]
    assert cursor.execute.call_count == 2
    assert connection.commit.call_count >= 1
    connection.close.assert_called()
def test_text_exists_and_get_by_ids(oracle_module):
    """text_exists reads a single row; get_by_ids maps rows to Documents and releases the pool."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.pool = MagicMock()
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    # Rows iterate as (meta, text) pairs.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
    # The connection must be handed back to the pool after the query.
    vector.pool.release.assert_called_once()
    # An empty id list short-circuits without touching the database.
    assert vector.get_by_ids([]) == []
def test_delete_methods(oracle_module):
    """The delete helpers emit the expected SQL and skip work for an empty id list."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    # Empty list: no connection should even be acquired.
    vector.delete_by_ids([])
    vector._get_connection.assert_not_called()
    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    # Verify each path produced its characteristic statement.
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql)
    assert any("JSON_VALUE(meta" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_search_by_vector_with_threshold_and_filter(oracle_module):
    """Vector search applies score_threshold and renders the document-id filter into SQL."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.input_type_handler = MagicMock()
    vector.output_type_handler = MagicMock()
    cursor = MagicMock()
    # Rows iterate as (meta, text, distance); distance 0.1 yields score 0.9 below.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, "doc-2", 0.8)])
    connection = _connection_with_cursor(cursor)
    vector._get_connection = MagicMock(return_value=connection)
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=0,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    # Only the high-score row passes the 0.5 threshold.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    sql = cursor.execute.call_args.args[0]
    # NOTE(review): top_k=0 produced "fetch first 4 rows only" — presumably a clamped
    # default row count; verify against the implementation.
    assert "fetch first 4 rows only" in sql
    assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql
def _fake_nltk_module(*, missing_data=False):
nltk = types.ModuleType("nltk")
nltk_corpus = types.ModuleType("nltk.corpus")
class _Data:
@staticmethod
def find(_path):
if missing_data:
raise LookupError("missing")
return True
nltk.data = _Data()
nltk.word_tokenize = lambda text: text.split()
nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"])
return nltk, nltk_corpus
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch):
    """Full-text search tokenizes via jieba for Chinese queries and NLTK for English ones."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])])
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    # Chinese path: jieba part-of-speech cut is stubbed.
    # NOTE(review): the stubbed tokens are empty strings — they look like Chinese
    # characters lost in transit; confirm the intended token words.
    monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("", "nr"), ("", "nr"), ("", "x")]))
    zh_docs = vector.search_by_full_text("张三", top_k=2)
    assert len(zh_docs) == 1
    zh_params = cursor.execute.call_args.args[1]
    assert zh_params["kk"] == "张三"
    # English path: inject fake nltk modules with available corpora.
    nltk, nltk_corpus = _fake_nltk_module(missing_data=False)
    monkeypatch.setitem(sys.modules, "nltk", nltk)
    monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus)
    cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])])
    en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"])
    assert len(en_docs) == 1
    en_sql = cursor.execute.call_args.args[0]
    en_params = cursor.execute.call_args.args[1]
    # NOTE(review): top_k=-1 produced "fetch first 5 rows only" — presumably a clamped
    # default; verify against the implementation.
    assert "fetch first 5 rows only" in en_sql
    # The document-id filter is bound as a named parameter.
    assert "doc_id_0" in en_params
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch):
    """An empty query yields an empty-content Document; missing NLTK data raises LookupError."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector._get_connection = MagicMock()
    empty_result = vector.search_by_full_text("")
    assert empty_result[0].page_content == ""
    # Simulate NLTK installed but its data packages absent.
    nltk, nltk_corpus = _fake_nltk_module(missing_data=True)
    monkeypatch.setitem(sys.modules, "nltk", nltk)
    monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus)
    with pytest.raises(LookupError, match="required NLTK data package"):
        vector.search_by_full_text("english query")
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
    """_create_collection is a no-op on a redis cache hit; on a miss it creates table and index."""
    # Stub the redis lock so the guarded section runs without a real redis server.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock())
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector._collection_name = "collection_1"
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    # Cache hit: no DDL is executed.
    monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(2)
    cursor.execute.assert_not_called()
    # Cache miss: table and full-text index are created, then the flag is cached.
    monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(2)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql)
    assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql)
    oracle_module.redis_client.set.assert_called_once()
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch):
    """The factory reuses an existing class_prefix and generates a name when none exists."""
    factory = oracle_module.OracleVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide a complete non-autonomous Oracle configuration.
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False)
    with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Unlike the OpenSearch factory, collection names keep their original casing here.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A dataset without an index struct gets one persisted on it.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,317 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy.types import UserDefinedType
from core.rag.models.document import Document
def _build_fake_pgvecto_modules():
    """Create stub ``pgvecto_rs`` modules exposing a minimal sqlalchemy VECTOR type."""

    class VECTOR(UserDefinedType):
        def __init__(self, dim):
            self.dim = dim

    root = types.ModuleType("pgvecto_rs")
    sqlalchemy_mod = types.ModuleType("pgvecto_rs.sqlalchemy")
    sqlalchemy_mod.VECTOR = VECTOR
    return {"pgvecto_rs": root, "pgvecto_rs.sqlalchemy": sqlalchemy_mod}
class _FakeSessionContext:
def __init__(self, calls, execute_results=None):
self.calls = calls
self.execute_results = execute_results or []
self.execute = MagicMock(side_effect=self._execute_side_effect)
self.commit = MagicMock()
def _execute_side_effect(self, *args, **kwargs):
self.calls.append((args, kwargs))
if self.execute_results:
return self.execute_results.pop(0)
return MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return None
def _session_factory(calls, execute_results=None):
    """Return a Session-like callable producing _FakeSessionContext objects bound to *calls*."""

    def _make_session(_client):
        return _FakeSessionContext(calls=calls, execute_results=execute_results)

    return _make_session
@pytest.fixture
def pgvecto_module(monkeypatch):
    """Import the pgvecto_rs vector and collection modules against stubbed pgvecto_rs packages."""
    # Install the fakes into sys.modules BEFORE importing so module-level imports resolve.
    for name, module in _build_fake_pgvecto_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module
    import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module
    # Reload both so previously imported copies do not leak into this test.
    return importlib.reload(module), importlib.reload(collection_module)
def _config(module, **overrides):
values = {
"host": "localhost",
"port": 5432,
"user": "postgres",
"password": "secret",
"database": "postgres",
}
values.update(overrides)
return module.PgvectoRSConfig.model_validate(values)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config PGVECTO_RS_HOST is required"),
        ("port", 0, "config PGVECTO_RS_PORT is required"),
        ("user", "", "config PGVECTO_RS_USER is required"),
        ("password", "", "config PGVECTO_RS_PASSWORD is required"),
        ("database", "", "config PGVECTO_RS_DATABASE is required"),
    ],
)
def test_pgvecto_config_validation(pgvecto_module, field, value, message):
    """Each mandatory connection field must fail validation with its own message when blank."""
    module, _ = pgvecto_module
    # Start from a valid config and blank out exactly one field per parametrized case.
    values = _config(module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        module.PgvectoRSConfig.model_validate(values)
def test_collection_base_has_expected_annotations(pgvecto_module):
    """The ORM base must declare the id/text/meta/vector attributes."""
    _, collection_module = pgvecto_module
    declared = set(collection_module.CollectionORM.__annotations__)
    assert declared >= {"id", "text", "meta", "vector"}
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
    """Construction builds the engine URL and enables the extension; create() delegates."""
    module, _ = pgvecto_module
    session_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(session_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == module.VectorType.PGVECTO_RS
    # The DSN is assembled from the config fields.
    module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres")
    # Construction must enable the "vectors" extension.
    assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls)
    # create() derives the dimension from the embedding width and forwards both args.
    vector.create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
    """create_collection is a no-op on a redis cache hit; a miss issues the table/index DDL."""
    module, _ = pgvecto_module
    session_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(session_calls))
    # Stub the redis lock so the guarded section runs without a real redis server.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(module.redis_client, "set", MagicMock())
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    # Cache hit: no CREATE TABLE must be issued.
    monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection(3)
    assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls)
    # Cache miss: table and embedding index are created, then the flag is cached.
    monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None))
    vector.create_collection(3)
    assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls)
    assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls)
    module.redis_client.set.assert_called()
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
    """add_texts inserts with generated uuids; lookup/delete helpers issue the expected SQL."""
    module, _ = pgvecto_module
    # Separate call logs: init_calls absorbs constructor SQL, runtime_calls the test actions.
    init_calls = []
    runtime_calls = []
    execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(init_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
    # Stand-in for sqlalchemy.insert that records the values it was given.
    class _InsertBuilder:
        def __init__(self, table):
            self.table = table
        def values(self, **kwargs):
            return ("insert", kwargs)
    monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table))
    monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"]))
    docs = [
        Document(page_content="a", metadata={"doc_id": "1"}),
        Document(page_content="b", metadata={"doc_id": "2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    # One generated uuid per document; each insert is executed through the session.
    assert ids == ["uuid-1", "uuid-2"]
    assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0])
    # get_ids_by_metadata_field: rows present -> list of ids.
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
    )
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
    # get_ids_by_metadata_field: no rows -> None.
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]),
    )
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
    # delete_by_metadata_field deletes the rows returned by the id lookup.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
    runtime_calls.clear()
    # delete_by_ids first resolves row ids via meta->>'doc_id', then deletes them.
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(
            runtime_calls,
            execute_results=[
                SimpleNamespace(fetchall=lambda: [("row-id-1",)]),
                MagicMock(),
            ],
        ),
    )
    vector.delete_by_ids(["doc-1"])
    assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
    assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
    runtime_calls.clear()
    # delete() drops the whole table.
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
    vector.delete()
    assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
    """text_exists reflects row presence; vector search filters by score; full-text returns []."""
    module, _ = pgvecto_module
    init_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(init_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    runtime_calls = []
    # First query returns a row (exists), the second none (does not exist).
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(
            runtime_calls,
            execute_results=[
                SimpleNamespace(fetchall=lambda: [("id-1",)]),
                SimpleNamespace(fetchall=lambda: []),
            ],
        ),
    )
    assert vector.text_exists("doc-1") is True
    assert vector.text_exists("doc-1") is False
    # Minimal stand-ins for the sqlalchemy expression objects used by search_by_vector.
    class _DistanceExpr:
        def label(self, _name):
            return self
    class _VectorColumn:
        def op(self, _operator, return_type=None):
            def _call(_query_vector):
                return _DistanceExpr()
            return _call
    class _MetaFilter:
        def in_(self, values):
            return ("in", values)
    class _MetaColumn:
        def __getitem__(self, _item):
            return _MetaFilter()
    class _Stmt:
        def __init__(self):
            self.where_called = False
        def limit(self, _value):
            return self
        def order_by(self, _value):
            return self
        def where(self, _value):
            self.where_called = True
            return self
    stmt = _Stmt()
    monkeypatch.setattr(module, "select", lambda *_args: stmt)
    vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn())
    # Result rows are (record, distance); distance 0.1 yields score 0.9 below.
    rows = [
        (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
        (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
    ]
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    # Only the high-score row passes the threshold; the id filter triggers a where clause.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    assert stmt.where_called is True
    # Full-text search is not implemented for this backend and returns an empty list.
    assert vector.search_by_full_text("hello") == []
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch):
    """The factory lowercases both an existing class_prefix and a freshly generated name."""
    module, _ = pgvecto_module
    factory = module.PGVectoRSFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide a complete connection configuration.
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432)
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres")
    # The factory probes the embedding dimension via embed_query.
    embeddings = MagicMock()
    embeddings.embed_query.return_value = [0.1, 0.2, 0.3]
    with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings)
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings)
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Both names are lowercased before being passed to the vector class.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one persisted on it.
    assert dataset_without_index.index_struct is not None

View File

@ -1,16 +1,19 @@
import unittest
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module
from core.rag.datasource.vdb.pgvector.pgvector import (
PGVector,
PGVectorConfig,
)
from core.rag.models.document import Document
class TestPGVector(unittest.TestCase):
def setUp(self):
class TestPGVector:
def setup_method(self, method):
self.config = PGVectorConfig(
host="localhost",
port=5432,
@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override):
PGVectorConfig(**config)
# Allow running this test module directly via the unittest CLI runner.
if __name__ == "__main__":
    unittest.main()
def test_create_delegates_collection_creation_and_insert():
    """create() must derive the dimension from the embeddings and delegate to add_texts."""
    store = PGVector.__new__(PGVector)
    store._create_collection = MagicMock()
    store.add_texts = MagicMock(return_value=["doc-a"])
    documents = [Document(page_content="hello", metadata={"doc_id": "doc-a"})]
    embeddings = [[0.1, 0.2]]
    assert store.create(documents, embeddings) == ["doc-a"]
    store._create_collection.assert_called_once_with(2)
    store.add_texts.assert_called_once_with(documents, embeddings)
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch):
    """add_texts batches inserts via psycopg2 execute_values and returns per-doc ids."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Replace _get_cursor with a context manager yielding our mock cursor.
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid")
    execute_values = MagicMock()
    monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values)
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        SimpleNamespace(page_content="c", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # Missing "doc_id" falls back to a generated uuid; the metadata-less doc yields no id.
    assert ids == ["doc-a", "generated-uuid"]
    execute_values.assert_called_once()
def test_text_get_and_delete_methods():
    """text_exists/get_by_ids read rows; delete helpers emit the expected SQL."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    # Rows iterate as (meta, text) pairs.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    # Each delete path must produce its characteristic statement.
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("meta->>%s = %s" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch):
    """delete_by_ids skips empty input, swallows UndefinedTable, and re-raises other errors."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # Empty list: no SQL must be executed.
    vector.delete_by_ids([])
    cursor.execute.assert_not_called()
    # A missing table is tolerated (the rows are already gone).
    class _UndefinedTableError(Exception):
        pass
    monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError)
    cursor.execute.side_effect = _UndefinedTableError("missing")
    vector.delete_by_ids(["doc-1"])
    # Any other error must propagate to the caller.
    cursor.execute.side_effect = RuntimeError("boom")
    with pytest.raises(RuntimeError, match="boom"):
        vector.delete_by_ids(["doc-1"])
def test_search_by_vector_supports_filter_and_threshold():
    """Vector search validates top_k, applies score_threshold, and inlines the id filter."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows iterate as (meta, text, distance); distance 0.1 yields score 0.9 below.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)])
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1], top_k=0)
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    # Only the high-score row passes the 0.5 threshold.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    sql = cursor.execute.call_args.args[0]
    assert "meta->>'document_id' in ('d-1')" in sql
def test_search_by_full_text_branches_for_bigm_and_standard():
    """Full-text search uses tsvector SQL by default and pg_bigm SQL when enabled."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)])
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("hello", top_k=0)
    # Standard branch: plainto_tsquery full-text matching.
    vector.pg_bigm = False
    docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.7)
    standard_sql = cursor.execute.call_args.args[0]
    assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql
    cursor.execute.reset_mock()
    cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)])
    # pg_bigm branch: lowers the similarity limit first, then queries bigm_similarity.
    vector.pg_bigm = True
    vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"])
    assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0]
    assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0]
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch):
    """Factory honours an existing index_struct collection name, else generates one."""
    factory = pgvector_module.PGVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # minimal config so PGVectorConfig validation passes inside the factory
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False)
    with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # a missing index struct must be generated and persisted on the dataset
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,269 @@
import importlib
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_psycopg2_modules():
psycopg2 = types.ModuleType("psycopg2")
psycopg2.__path__ = []
psycopg2_extras = types.ModuleType("psycopg2.extras")
psycopg2_pool = types.ModuleType("psycopg2.pool")
class SimpleConnectionPool:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.getconn = MagicMock()
self.putconn = MagicMock()
psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool
psycopg2_extras.execute_values = MagicMock()
psycopg2.pool = psycopg2_pool
psycopg2.extras = psycopg2_extras
return {
"psycopg2": psycopg2,
"psycopg2.pool": psycopg2_pool,
"psycopg2.extras": psycopg2_extras,
}
@pytest.fixture
def vastbase_module(monkeypatch):
    """Import the vastbase vector module against fake psycopg2 modules."""
    fakes = _build_fake_psycopg2_modules()
    for name in fakes:
        monkeypatch.setitem(sys.modules, name, fakes[name])
    import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module
    return importlib.reload(module)
def _config(module):
return module.VastbaseVectorConfig(
host="localhost",
port=5432,
user="dify",
password="secret",
database="dify",
min_connection=1,
max_connection=5,
)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config VASTBASE_HOST is required"),
        ("port", 0, "config VASTBASE_PORT is required"),
        ("user", "", "config VASTBASE_USER is required"),
        ("password", "", "config VASTBASE_PASSWORD is required"),
        ("database", "", "config VASTBASE_DATABASE is required"),
        ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"),
        ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"),
    ],
)
def test_vastbase_config_validation(vastbase_module, field, value, message):
    """Each required config field rejects an empty/zero value with a targeted message."""
    payload = _config(vastbase_module).model_dump()
    payload[field] = value
    with pytest.raises(ValidationError, match=message):
        vastbase_module.VastbaseVectorConfig.model_validate(payload)
def test_vastbase_config_rejects_invalid_connection_window(vastbase_module):
    """min_connection greater than max_connection must be rejected."""
    payload = {
        "host": "localhost",
        "port": 5432,
        "user": "dify",
        "password": "secret",
        "database": "dify",
        "min_connection": 6,
        "max_connection": 5,
    }
    with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"):
        vastbase_module.VastbaseVectorConfig.model_validate(payload)
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
    """Constructor wires the pool; _get_cursor yields the cursor then commits and releases."""
    pool = MagicMock()
    monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    conn = MagicMock()
    cur = MagicMock()
    pool.getconn.return_value = conn
    conn.cursor.return_value = cur
    vector = vastbase_module.VastbaseVector("collection_1", _config(vastbase_module))
    assert vector.get_type() == "vastbase"
    assert vector.table_name == "embedding_collection_1"
    with vector._get_cursor() as got_cur:
        assert got_cur is cur
    # leaving the context must close the cursor, commit, and return the connection to the pool
    cur.close.assert_called_once()
    conn.commit.assert_called_once()
    pool.putconn.assert_called_once_with(conn)
def test_create_and_add_texts(vastbase_module, monkeypatch):
    """add_texts keeps doc_id metadata ids, generates uuids otherwise, and create() delegates."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)  # skip __init__/pool setup
    vector.table_name = "embedding_collection_1"
    vector._create_collection = MagicMock()
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),  # no doc_id -> uuid fallback
        SimpleNamespace(page_content="c", metadata=None),  # metadata=None rows are skipped
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    assert ids == ["doc-a", "generated-uuid"]
    vastbase_module.psycopg2.extras.execute_values.assert_called_once()
    vector.add_texts = MagicMock(return_value=["doc-a"])
    result = vector.create(docs, [[0.1], [0.2], [0.3]])
    # create() first builds the collection sized by the embedding dimension, then inserts
    vector._create_collection.assert_called_once_with(1)
    assert result == ["doc-a"]
def test_text_get_delete_and_metadata_methods(vastbase_module):
    """Covers text_exists, get_by_ids, and the delete/drop SQL surfaces."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
    vector.delete_by_ids([])  # empty id list must be a silent no-op
    vector.delete_by_ids(["id-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    # verify each operation produced its expected SQL shape
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql)
    assert any("meta->>%s = %s" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_search_by_vector_and_full_text(vastbase_module):
    """Vector search applies the score threshold; full-text search maps rows to Documents."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # rows are (meta, text, distance); score is 1 - distance downstream
    cursor.__iter__.return_value = iter(
        [
            ({"doc_id": "1"}, "text-1", 0.1),
            ({"doc_id": "2"}, "text-2", 0.8),
        ]
    )
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1, 0.2], top_k=0)
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    # only the row with distance 0.1 (score 0.9) survives the threshold
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("hello", top_k=0)
    cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)])
    full_docs = vector.search_by_full_text("hello world", top_k=2)
    assert len(full_docs) == 1
    assert full_docs[0].page_content == "full-text"
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch):
    """Redis cache short-circuits creation; index creation depends on the dimension cap."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock())
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector._collection_name = "collection_1"
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # cache hit: nothing is executed
    monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    cursor.execute.assert_not_called()
    # cache miss with an oversized dimension: table created but no cosine index
    monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(17000)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql)
    assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql)
    # cache miss with a small dimension: cosine index is created too
    cursor.execute.reset_mock()
    vector._create_collection(3)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql)
    vastbase_module.redis_client.set.assert_called()
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch):
    """Factory honours an existing index_struct collection name, else generates one."""
    factory = vastbase_module.VastbaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # minimal config so VastbaseVectorConfig validation passes inside the factory
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432)
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1)
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MAX_CONNECTION", 5)
    with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # a missing index struct must be generated and persisted on the dataset
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,328 @@
import importlib
import os
import sys
import types
from collections import UserDict
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_qdrant_modules():
    """Fabricate the qdrant_client module tree (http models, exceptions, local client) for sys.modules patching."""
    qdrant_client = types.ModuleType("qdrant_client")
    qdrant_http = types.ModuleType("qdrant_client.http")
    qdrant_http_models = types.ModuleType("qdrant_client.http.models")
    qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions")
    qdrant_local_pkg = types.ModuleType("qdrant_client.local")
    qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local")
    class UnexpectedResponseError(Exception):
        # Mirrors qdrant's UnexpectedResponse: carries the HTTP status code.
        def __init__(self, status_code):
            super().__init__(f"status={status_code}")
            self.status_code = status_code
    # Minimal model stand-ins: each just records its constructor arguments.
    class FilterSelector:
        def __init__(self, filter):
            self.filter = filter
    class HnswConfigDiff:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
    class TextIndexParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
    class VectorParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
    class PointStruct:
        def __init__(self, **kwargs):
            self.id = kwargs["id"]
            self.vector = kwargs["vector"]
            self.payload = kwargs["payload"]
    class Filter:
        def __init__(self, must=None):
            self.must = must or []
    class FieldCondition:
        def __init__(self, key, match):
            self.key = key
            self.match = match
    class MatchValue:
        def __init__(self, value):
            self.value = value
    class MatchAny:
        def __init__(self, any):
            self.any = any
    class MatchText:
        def __init__(self, text):
            self.text = text
    class _Distance(UserDict):
        # Identity mapping so Distance[...] lookups echo the key back.
        def __getitem__(self, key):
            return key
    class QdrantClient:
        # Records constructor kwargs and exposes the API surface as MagicMocks.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[]))
            self.create_collection = MagicMock()
            self.create_payload_index = MagicMock()
            self.upsert = MagicMock()
            self.delete = MagicMock()
            self.delete_collection = MagicMock()
            self.retrieve = MagicMock(return_value=[])
            self.search = MagicMock(return_value=[])
            self.scroll = MagicMock(return_value=([], None))
    class QdrantLocal(QdrantClient):
        # Local variant adds the _load hook exercised by _reload_if_needed.
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._load = MagicMock()
    # Wire the fake classes onto the fake modules under their real names.
    qdrant_client.QdrantClient = QdrantClient
    qdrant_http_models.FilterSelector = FilterSelector
    qdrant_http_models.HnswConfigDiff = HnswConfigDiff
    qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD")
    qdrant_http_models.TextIndexParams = TextIndexParams
    qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT")
    qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL")
    qdrant_http_models.VectorParams = VectorParams
    qdrant_http_models.Distance = _Distance()
    qdrant_http_models.PointStruct = PointStruct
    qdrant_http_models.Filter = Filter
    qdrant_http_models.FieldCondition = FieldCondition
    qdrant_http_models.MatchValue = MatchValue
    qdrant_http_models.MatchAny = MatchAny
    qdrant_http_models.MatchText = MatchText
    qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError
    qdrant_http.models = qdrant_http_models
    qdrant_local_mod.QdrantLocal = QdrantLocal
    qdrant_local_pkg.qdrant_local = qdrant_local_mod
    return {
        "qdrant_client": qdrant_client,
        "qdrant_client.http": qdrant_http,
        "qdrant_client.http.models": qdrant_http_models,
        "qdrant_client.http.exceptions": qdrant_http_exceptions,
        "qdrant_client.local": qdrant_local_pkg,
        "qdrant_client.local.qdrant_local": qdrant_local_mod,
    }
@pytest.fixture
def qdrant_module(monkeypatch):
    """Import the qdrant vector module against fake qdrant_client modules."""
    fakes = _build_fake_qdrant_modules()
    for name in fakes:
        monkeypatch.setitem(sys.modules, name, fakes[name])
    import core.rag.datasource.vdb.qdrant.qdrant_vector as module
    return importlib.reload(module)
def _config(module, **overrides):
values = {
"endpoint": "http://localhost:6333",
"api_key": "api-key",
"timeout": 20,
"root_path": "/tmp",
"grpc_port": 6334,
"prefer_grpc": False,
"replication_factor": 1,
"write_consistency_factor": 1,
}
values.update(overrides)
return module.QdrantConfig.model_validate(values)
def test_qdrant_config_to_params(qdrant_module):
    """URL endpoints map to url params; path: endpoints resolve against root_path."""
    params = _config(qdrant_module).to_qdrant_params().model_dump()
    assert params["url"] == "http://localhost:6333"
    assert params["verify"] is False
    local_config = _config(qdrant_module, endpoint="path:storage")
    assert local_config.to_qdrant_params().path == os.path.join("/tmp", "storage")
    with pytest.raises(ValueError, match="Root path is not set"):
        _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params()
def test_init_and_basic_behaviour(qdrant_module):
    """get_type/to_index_struct report the collection; create() delegates to helpers."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    assert vector.get_type() == qdrant_module.VectorType.QDRANT
    assert vector.to_index_struct()["vector_store"]["class_prefix"] == "collection_1"
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    documents = [Document(page_content="a", metadata={"doc_id": "a"})]
    vector.create(documents, [[0.1]])
    vector.create_collection.assert_called_once_with("collection_1", 1)
    vector.add_texts.assert_called_once()
def test_create_collection_and_add_texts(qdrant_module, monkeypatch):
    """Collection creation is redis-cached; add_texts upserts; payload building validates texts."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock())
    # cache hit: creation is skipped entirely
    monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection("collection_1", 3)
    vector._client.create_collection.assert_not_called()
    # cache miss: collection plus four payload indexes are created, cache key set
    monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.get_collections.return_value = SimpleNamespace(collections=[])
    vector.create_collection("collection_1", 3)
    vector._client.create_collection.assert_called_once()
    assert vector._client.create_payload_index.call_count == 4
    qdrant_module.redis_client.set.assert_called_once()
    # add_texts and generated batches
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    assert vector._client.upsert.call_count == 1
    # payload builder injects the group id under the requested key
    payloads = qdrant_module.QdrantVector._build_payloads(
        ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id"
    )
    assert payloads[0]["group_id"] == "g1"
    # None texts must be rejected explicitly
    with pytest.raises(ValueError, match="At least one of the texts is None"):
        qdrant_module.QdrantVector._build_payloads(
            [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id"
        )
def test_delete_and_exists_paths(qdrant_module):
    """404 responses are tolerated on delete paths, other statuses propagate; text_exists checks collection membership.

    Fix: removed the redundant `side_effect = None` assignments that were
    immediately overwritten by the next `side_effect = unexpected(...)` line.
    """
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse
    # delete_by_metadata_field: 404 swallowed, 500 re-raised
    vector._client.delete.side_effect = unexpected(404)
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete_by_metadata_field("document_id", "doc-1")
    # delete: same tolerance contract
    vector._client.delete.side_effect = unexpected(404)
    vector.delete()
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete()
    # delete_by_ids: same tolerance contract
    vector._client.delete.side_effect = unexpected(404)
    vector.delete_by_ids(["doc-1"])
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete_by_ids(["doc-1"])
    vector._client.delete.side_effect = None
    # text_exists: False when the collection is absent, True when the point is retrievable
    vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")])
    assert vector.text_exists("id-1") is False
    vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")])
    vector._client.retrieve.return_value = [{"id": "id-1"}]
    assert vector.text_exists("id-1") is True
def test_search_and_helper_methods(qdrant_module):
    """Vector-search thresholding, full-text dedup, local client reload, and payload mapping."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    # no hits configured -> empty result even before thresholding
    assert vector.search_by_vector([0.1], score_threshold=1.0) == []
    vector._client.search.return_value = [
        SimpleNamespace(payload=None, score=0.9, vector=[0.1]),  # payload-less hits are skipped
        SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]),
    ]
    docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    # full text search: keyword split, dedup and top_k limit
    scroll_results = [
        (
            [
                SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]),
                SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]),
            ],
            None,
        ),
        (
            [
                SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]),
            ],
            None,
        ),
    ]
    vector._client.scroll.side_effect = scroll_results
    docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"])
    # duplicate point p2 across scrolls is deduplicated
    assert len(docs) == 2
    # a blank query short-circuits to an empty list
    assert vector.search_by_full_text(" ", top_k=2) == []
    # QdrantLocal clients must be re-loaded on demand
    local_client = qdrant_module.QdrantLocal()
    vector._client = local_client
    vector._reload_if_needed()
    local_client._load.assert_called_once()
    doc = vector._document_from_scored_point(
        SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]),
        "page_content",
        "metadata",
    )
    assert doc.page_content == "doc"
def test_qdrant_factory_paths(qdrant_module, monkeypatch):
    """Factory resolves the collection name from the dataset, a binding row, or fails loudly."""
    factory = qdrant_module.QdrantVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        tenant_id="tenant-1",
        collection_binding_id=None,
        index_struct_dict=None,
        index_struct=None,
    )
    monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root")))
    # minimal config so QdrantConfig validation passes inside the factory
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333")
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key")
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_REPLICATION_FACTOR", 1)
    # no binding: generated collection name is used and index_struct persisted
    with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset.index_struct is not None
    # collection binding lookup path
    dataset.collection_binding_id = "binding-1"
    dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}}
    monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    qdrant_module.db.session.scalars = MagicMock(
        return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION"))
    )
    with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls:
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION"
    # a missing binding row must raise a descriptive error
    qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None))
    with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"):
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())

View File

@ -0,0 +1,303 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy.types import UserDefinedType
from core.rag.models.document import Document
def _build_fake_relyt_modules():
    """Fabricate pgvecto_rs modules exposing a minimal VECTOR sqlalchemy type."""
    root = types.ModuleType("pgvecto_rs")
    sqlalchemy_mod = types.ModuleType("pgvecto_rs.sqlalchemy")

    class VECTOR(UserDefinedType):
        """Stores only the requested dimension, like the real column type."""

        def __init__(self, dim):
            self.dim = dim

    sqlalchemy_mod.VECTOR = VECTOR
    return {"pgvecto_rs": root, "pgvecto_rs.sqlalchemy": sqlalchemy_mod}
class _FakeSession:
def __init__(self, execute_result=None):
self.execute_result = execute_result or MagicMock(fetchall=lambda: [])
self.execute = MagicMock(return_value=self.execute_result)
self.commit = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return None
@pytest.fixture
def relyt_module(monkeypatch):
    """Import the relyt vector module against fake pgvecto_rs modules."""
    fakes = _build_fake_relyt_modules()
    for name in fakes:
        monkeypatch.setitem(sys.modules, name, fakes[name])
    import core.rag.datasource.vdb.relyt.relyt_vector as module
    return importlib.reload(module)
def _config(module, **overrides):
values = {
"host": "localhost",
"port": 5432,
"user": "postgres",
"password": "secret",
"database": "relyt",
}
values.update(overrides)
return module.RelytConfig.model_validate(values)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config RELYT_HOST is required"),
        ("port", 0, "config RELYT_PORT is required"),
        ("user", "", "config RELYT_USER is required"),
        ("password", "", "config RELYT_PASSWORD is required"),
        ("database", "", "config RELYT_DATABASE is required"),
    ],
)
def test_relyt_config_validation(relyt_module, field, value, message):
    """Each required config field rejects an empty/zero value with a targeted message."""
    payload = _config(relyt_module).model_dump()
    payload[field] = value
    with pytest.raises(ValidationError, match=message):
        relyt_module.RelytConfig.model_validate(payload)
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
    """Constructor derives the DB URL; create() infers the dimension and delegates."""
    engine = MagicMock()
    monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine))
    vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1")
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == relyt_module.VectorType.RELYT
    # connection URL is assembled from the config fields
    assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt"
    # dimension comes from the first embedding's length
    assert vector.embedding_dimension == 2
    vector.create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
    """Redis cache short-circuits collection creation; a miss runs DROP/CREATE/INDEX SQL."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock())
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    # cache hit: no SQL executed
    monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
    session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    vector.create_collection(3)
    session.execute.assert_not_called()
    # cache miss: full DDL sequence is issued and the cache key is set
    monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
    session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    vector.create_collection(3)
    executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
    assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
    assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql)
    assert any("CREATE INDEX" in sql for sql in executed_sql)
    relyt_module.redis_client.set.assert_called_once()
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch):
    """add_texts inserts rows carrying group_id; metadata lookups return ids or None."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector._group_id = "group-1"
    vector.client = MagicMock()
    # fake engine.connect() -> conn exposing a begin() transaction context
    begin_ctx = MagicMock()
    begin_ctx.__enter__.return_value = None
    begin_ctx.__exit__.return_value = None
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.begin.return_value = begin_ctx
    vector.client.connect.return_value = conn
    monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"]))
    docs = [
        Document(page_content="a", metadata={"doc_id": "d-1"}),
        Document(page_content="b", metadata={"doc_id": "d-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    assert conn.execute.call_count >= 1
    # the compiled insert statement must carry the group_id column
    first_insert_values = conn.execute.call_args.args[0].compile().params
    assert "group_id" in str(first_insert_values)
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)]))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"]
    # no matching rows -> None rather than an empty list
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: []))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
# 1. delete_by_uuids: success and connect error
def test_delete_by_uuids_success_and_connect_error(relyt_module):
    """delete_by_uuids rejects missing ids, returns True on success and False on connect failure."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    with pytest.raises(ValueError, match="No ids provided"):
        relyt_vector.delete_by_uuids(None)
    txn = MagicMock()
    txn.__enter__.return_value = None
    txn.__exit__.return_value = None
    connection = MagicMock()
    connection.__enter__.return_value = connection
    connection.__exit__.return_value = None
    connection.begin.return_value = txn
    relyt_vector.client.connect.return_value = connection
    assert relyt_vector.delete_by_uuids(["id-1"]) is True
    relyt_vector.client.connect.side_effect = RuntimeError("boom")
    assert relyt_vector.delete_by_uuids(["id-1"]) is False
# 2. delete_by_metadata_field calls delete_by_uuids
def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module):
    """delete_by_metadata_field resolves ids via metadata then delegates the deletion."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    relyt_vector.delete_by_uuids = MagicMock(return_value=True)
    relyt_vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    relyt_vector.delete_by_metadata_field("document_id", "doc-1")
    relyt_vector.delete_by_uuids.assert_called_once_with(["id-1"])
# 3. delete_by_ids translates to uuids
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
    """delete_by_ids resolves document ids to row uuids before deleting."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    rows = MagicMock(fetchall=lambda: [("uuid-1",), ("uuid-2",)])
    monkeypatch.setattr(relyt_module, "Session", lambda _client: _FakeSession(execute_result=rows))
    relyt_vector.delete_by_uuids = MagicMock(return_value=True)
    relyt_vector.delete_by_ids(["doc-1", "doc-2"])
    relyt_vector.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"])
# 4. text_exists True
def test_text_exists_true(relyt_module, monkeypatch):
    """text_exists reports True when a matching row comes back."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    rows = MagicMock(fetchall=lambda: [("id-1",)])
    monkeypatch.setattr(relyt_module, "Session", lambda _client: _FakeSession(execute_result=rows))
    assert relyt_vector.text_exists("doc-1") is True
# 5. text_exists False
def test_text_exists_false(relyt_module, monkeypatch):
    """text_exists reports False when the lookup yields no rows."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    rows = MagicMock(fetchall=lambda: [])
    monkeypatch.setattr(relyt_module, "Session", lambda _client: _FakeSession(execute_result=rows))
    assert relyt_vector.text_exists("doc-1") is False
# 6. similarity_search_with_score_by_vector returns Documents and scores
def test_similarity_search_with_score_by_vector(relyt_module):
    """Driver rows are mapped to (Document, distance) tuples in order."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    result_rows = [
        SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8),
    ]
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value.fetchall.return_value = result_rows
    vector.client.connect.return_value = conn
    similarities = vector.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]})
    # both rows come back; no thresholding happens at this layer
    assert len(similarities) == 2
    assert similarities[0][0].page_content == "doc-a"
# 7. search_by_vector filters by score and ids
def test_search_by_vector_filters_by_score_and_ids(relyt_module):
    """High-distance hits are dropped by the threshold; full-text search returns []."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    hits = [
        (Document(page_content="a", metadata={"doc_id": "1"}), 0.1),
        (Document(page_content="b", metadata={}), 0.9),
    ]
    relyt_vector.similarity_search_with_score_by_vector = MagicMock(return_value=hits)
    results = relyt_vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(results) == 1
    assert relyt_vector.search_by_full_text("query") == []
# 8. delete commits session
def test_delete_commits_session(relyt_module, monkeypatch):
    """delete() runs inside a session and commits."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    fake_session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: fake_session)
    relyt_vector.delete()
    fake_session.commit.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
    """The factory reuses a stored class_prefix, else generates a collection name."""
    factory = relyt_module.RelytVectorFactory()
    # Dataset that already carries an index struct with a class_prefix.
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(relyt_module.dify_config, "RELYT_HOST", "localhost")
    monkeypatch.setattr(relyt_module.dify_config, "RELYT_PORT", 5432)
    monkeypatch.setattr(relyt_module.dify_config, "RELYT_USER", "postgres")
    monkeypatch.setattr(relyt_module.dify_config, "RELYT_PASSWORD", "secret")
    monkeypatch.setattr(relyt_module.dify_config, "RELYT_DATABASE", "relyt")
    with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A missing index struct is persisted onto the dataset by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,316 @@
import importlib
import json
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_tablestore_module():
    """Build a stub ``tablestore`` SDK module suitable for sys.modules patching.

    Only the names the vector module imports are provided; client methods are
    MagicMocks so tests can assert on the calls made by the code under test.
    """
    tablestore = types.ModuleType("tablestore")

    class _BatchGetRowRequest:
        def __init__(self):
            self.items = []

        def add(self, item):
            self.items.append(item)

    class _TableInBatchGetRowItem:
        def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver):
            self.table_name = table_name
            self.rows_to_get = rows_to_get
            self.columns_to_get = columns_to_get

    class _Row:
        def __init__(self, primary_key, attribute_columns=None):
            self.primary_key = primary_key
            self.attribute_columns = attribute_columns or []

    class _Client:
        # Every OTS client API used in the tests is an inspectable MagicMock.
        def __init__(self, *_args):
            self.list_table = MagicMock(return_value=[])
            self.create_table = MagicMock()
            self.list_search_index = MagicMock(return_value=[])
            self.create_search_index = MagicMock()
            self.delete_search_index = MagicMock()
            self.delete_table = MagicMock()
            self.put_row = MagicMock()
            self.delete_row = MagicMock()
            self.get_row = MagicMock(return_value=(None, None, None))
            self.batch_get_row = MagicMock()
            self.search = MagicMock()

    tablestore.OTSClient = _Client
    tablestore.BatchGetRowRequest = _BatchGetRowRequest
    tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem
    tablestore.Row = _Row
    # Structural helpers simply echo their arguments as inspectable tuples
    # or SimpleNamespace objects.
    tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema)
    tablestore.TableOptions = lambda: ("table_options",)
    tablestore.CapacityUnit = lambda read, write: ("capacity", read, write)
    tablestore.ReservedThroughput = lambda cap: ("reserved", cap)
    tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs)
    tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs)
    tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas)
    tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs)
    tablestore.TermQuery = lambda key, value: ("term_query", key, value)
    tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs)
    tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs)
    tablestore.TermsQuery = lambda key, values: ("terms_query", key, values)
    tablestore.Sort = lambda **kwargs: ("sort", kwargs)
    tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs)
    tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs)
    tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs)
    # Enum-like namespaces with the constants the vector module references.
    tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD")
    tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD")
    tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32")
    tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE")
    tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX")
    tablestore.SortOrder = SimpleNamespace(DESC="DESC")
    return tablestore
@pytest.fixture
def tablestore_module(monkeypatch):
    """Reload the tablestore vector module against the stubbed SDK."""
    monkeypatch.setitem(sys.modules, "tablestore", _build_fake_tablestore_module())
    import core.rag.datasource.vdb.tablestore.tablestore_vector as module

    return importlib.reload(module)
def _config(module, **overrides):
    """Build a valid TableStoreConfig, optionally overriding individual fields."""
    base = {
        "access_key_id": "ak",
        "access_key_secret": "sk",
        "instance_name": "instance",
        "endpoint": "endpoint",
        "normalize_full_text_bm25_score": False,
    }
    return module.TableStoreConfig.model_validate({**base, **overrides})
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("access_key_id", "", "config ACCESS_KEY_ID is required"),
        ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"),
        ("instance_name", "", "config INSTANCE_NAME is required"),
        ("endpoint", "", "config ENDPOINT is required"),
    ],
)
def test_tablestore_config_validation(tablestore_module, field, value, message):
    """Each required config field rejects an empty value with a clear message."""
    values = _config(tablestore_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        tablestore_module.TableStoreConfig.model_validate(values)
def test_init_and_basic_delegation(tablestore_module):
    """Constructor wiring plus create()/create_collection() delegation."""
    store = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    assert store.get_type() == tablestore_module.VectorType.TABLESTORE
    assert store._table_name == "collection_1"
    assert store._index_name == "collection_1_idx"
    store._create_collection = MagicMock()
    store.add_texts = MagicMock()
    documents = [Document(page_content="hello", metadata={"doc_id": "d-1"})]
    vectors = [[0.1, 0.2]]
    store.create(documents, vectors)
    store._create_collection.assert_called_once_with(2)
    store.add_texts.assert_called_once_with(documents=documents, embeddings=vectors)
    store.create_collection(vectors)
    assert store._create_collection.call_count == 2
def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module):
    """Cover row lookup, existence checks, delete wrappers and search delegation."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    # get_by_ids: only batch items flagged is_ok become Documents.
    ok_item = SimpleNamespace(
        is_ok=True,
        row=SimpleNamespace(
            attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)]
        ),
    )
    fail_item = SimpleNamespace(is_ok=False, row=None)
    batch_resp = SimpleNamespace(get_result_by_table=lambda _table: [ok_item, fail_item])
    vector._tablestore_client.batch_get_row.return_value = batch_resp
    docs = vector.get_by_ids(["id-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "text-1"
    # text_exists: truthiness of the returned row decides the result.
    vector._tablestore_client.get_row.return_value = (None, object(), None)
    assert vector.text_exists("id-1") is True
    vector._tablestore_client.get_row.return_value = (None, None, None)
    assert vector.text_exists("id-1") is False
    # delete wrappers: an empty id list short-circuits, otherwise one row per id.
    vector._delete_row = MagicMock()
    vector.delete_by_ids([])
    vector._delete_row.assert_not_called()
    vector.delete_by_ids(["id-1", "id-2"])
    assert vector._delete_row.call_count == 2
    # metadata lookups and deletion-by-metadata delegate through _search_by_metadata.
    vector._search_by_metadata = MagicMock(return_value=["id-a"])
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"]
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-a"])
    # Public search entry points forward to the private search helpers.
    vector._search_by_vector = MagicMock(return_value=["vec-doc"])
    vector._search_by_full_text = MagicMock(return_value=["fts-doc"])
    assert vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"]
    assert vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == ["fts-doc"]
    vector._delete_table_if_exist = MagicMock()
    vector.delete()
    vector._delete_table_if_exist.assert_called_once()
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch):
    """Exercise redis-gated collection creation plus table/index lifecycle helpers."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock())
    # Cache hit in redis: creation is skipped entirely.
    monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_table_if_not_exist = MagicMock()
    vector._create_search_index_if_not_exist = MagicMock()
    vector._create_collection(3)
    vector._create_table_if_not_exist.assert_not_called()
    # Cache miss: table and search index are created, and the flag is set.
    monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(3)
    vector._create_table_if_not_exist.assert_called_once()
    vector._create_search_index_if_not_exist.assert_called_once_with(3)
    tablestore_module.redis_client.set.assert_called_once()
    # Fresh instance for exercising the raw table/index helpers.
    vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module))
    vector._tablestore_client.list_table.return_value = ["collection_2"]
    assert vector._create_table_if_not_exist() is None
    vector._tablestore_client.list_table.return_value = []
    vector._create_table_if_not_exist()
    vector._tablestore_client.create_table.assert_called_once()
    vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")]
    assert vector._create_search_index_if_not_exist(3) is None
    vector._tablestore_client.list_search_index.return_value = []
    vector._create_search_index_if_not_exist(3)
    vector._tablestore_client.create_search_index.assert_called_once()
    # Deleting the table first drops every search index attached to it.
    vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")]
    vector._delete_table_if_exist()
    assert vector._tablestore_client.delete_search_index.call_count == 2
    vector._tablestore_client.delete_table.assert_called_once_with("collection_2")
    vector._delete_search_index()
    vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx")
def test_write_row_and_search_helpers(tablestore_module):
    """Cover row writes, metadata pagination, vector search and BM25 score handling."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    vector._write_row(
        "id-1",
        {
            "page_content": "hello",
            "vector": [0.1, 0.2],
            "metadata": {"doc_id": "d-1", "document_id": "doc-1"},
        },
    )
    put_row_call = vector._tablestore_client.put_row.call_args
    assert put_row_call.args[0] == "collection_1"
    attrs = put_row_call.args[1].attribute_columns
    # A dedicated metadata_tags column is written alongside the raw metadata.
    assert any(item[0] == "metadata_tags" for item in attrs)
    vector._delete_row("id-1")
    vector._tablestore_client.delete_row.assert_called_once()
    # metadata search pagination: a non-empty next_token triggers a second page.
    first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next")
    second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"")
    vector._tablestore_client.search.side_effect = [first_page, second_page]
    ids = vector._search_by_metadata("document_id", "doc-1")
    assert ids == ["row-1", "row-2"]
    vector._tablestore_client.search.side_effect = None
    # vector search: hits below the score threshold are dropped.
    hit1 = SimpleNamespace(
        score=0.9,
        row=(
            None,
            [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))],
        ),
    )
    hit2 = SimpleNamespace(
        score=0.2,
        row=(
            None,
            [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))],
        ),
    )
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2])
    docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # The exp-decay normalization keeps scores within [0, 1].
    assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0)
    assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0
    # full text search with and without normalized score filter
    vector._normalize_full_text_bm25_score = True
    hit3 = SimpleNamespace(
        score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))])
    )
    hit4 = SimpleNamespace(
        score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))])
    )
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4])
    docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2)
    assert len(docs) == 1
    assert "score" in docs[0].metadata
    # Without normalization, raw BM25 scores are not attached to metadata.
    vector._normalize_full_text_bm25_score = False
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3])
    docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0)
    assert len(docs) == 1
    assert "score" not in docs[0].metadata
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch):
    """The factory reuses a stored class_prefix, else generates a collection name."""
    factory = tablestore_module.TableStoreVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ENDPOINT", "endpoint")
    monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_INSTANCE_NAME", "instance")
    monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_ID", "ak")
    monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_SECRET", "sk")
    monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE", True)
    with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A missing index struct is persisted onto the dataset by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,309 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_tencent_modules():
    """Construct stub tcvdb_text / tcvectordb module trees for sys.modules patching.

    Returns a mapping of dotted module name -> stub module covering every
    import the Tencent vector module performs.
    """
    tcvdb_text = types.ModuleType("tcvdb_text")
    tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder")
    tcvectordb = types.ModuleType("tcvectordb")
    tcvectordb_model = types.ModuleType("tcvectordb.model")
    tcvectordb_document = types.ModuleType("tcvectordb.model.document")
    tcvectordb_index = types.ModuleType("tcvectordb.model.index")
    tcvectordb_enum = types.ModuleType("tcvectordb.model.enum")

    class _BM25Encoder:
        # Sparse-vector encoder stub; returns inspectable payloads.
        def encode_texts(self, text):
            return {"encoded_text": text}

        def encode_queries(self, query):
            return {"encoded_query": query}

        @classmethod
        def default(cls, _lang):
            return cls()

    class VectorDBError(Exception):
        def __init__(self, message):
            super().__init__(message)
            self.message = message

    class RPCVectorDBClient:
        # Records constructor kwargs and mocks every API entry point used.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.create_database_if_not_exists = MagicMock()
            self.exists_collection = MagicMock(return_value=False)
            self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[]))
            self.create_collection = MagicMock()
            self.upsert = MagicMock()
            self.query = MagicMock(return_value=[])
            self.delete = MagicMock()
            self.search = MagicMock(return_value=[])
            self.hybrid_search = MagicMock(return_value=[])
            self.drop_collection = MagicMock()

    class _Document:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    class _HNSWSearchParams:
        def __init__(self, ef):
            self.ef = ef

    class _AnnSearch:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _KeywordSearch:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _WeightedRerank:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _Filter:
        @staticmethod
        def in_(field, values):
            return ("in", field, values)

        def __init__(self, condition):
            self.condition = condition

    # Provide `In` as a static alias of `in_`.
    _Filter.In = staticmethod(_Filter.in_)

    class _HNSWParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _FilterIndex:
        def __init__(self, *args):
            self.args = args

    class _VectorIndex:
        def __init__(self, *args):
            self.args = args

    class _SparseIndex:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    # Enum-like namespaces with both attribute access and __members__.
    tcvectordb_enum.IndexType = SimpleNamespace(
        __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"},
        PRIMARY_KEY="PRIMARY_KEY",
        FILTER="FILTER",
        SPARSE_INVERTED="SPARSE",
    )
    tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP")
    tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector")
    tcvectordb_document.Document = _Document
    tcvectordb_document.HNSWSearchParams = _HNSWSearchParams
    tcvectordb_document.AnnSearch = _AnnSearch
    tcvectordb_document.Filter = _Filter
    tcvectordb_document.KeywordSearch = _KeywordSearch
    tcvectordb_document.WeightedRerank = _WeightedRerank
    tcvectordb_index.HNSWParams = _HNSWParams
    tcvectordb_index.FilterIndex = _FilterIndex
    tcvectordb_index.VectorIndex = _VectorIndex
    tcvectordb_index.SparseIndex = _SparseIndex
    tcvdb_text_encoder.BM25Encoder = _BM25Encoder
    tcvectordb_model.document = tcvectordb_document
    tcvectordb_model.enum = tcvectordb_enum
    tcvectordb_model.index = tcvectordb_index
    tcvectordb.RPCVectorDBClient = RPCVectorDBClient
    tcvectordb.VectorDBException = VectorDBError
    return {
        "tcvdb_text": tcvdb_text,
        "tcvdb_text.encoder": tcvdb_text_encoder,
        "tcvectordb": tcvectordb,
        "tcvectordb.model": tcvectordb_model,
        "tcvectordb.model.document": tcvectordb_document,
        "tcvectordb.model.index": tcvectordb_index,
        "tcvectordb.model.enum": tcvectordb_enum,
    }
@pytest.fixture
def tencent_module(monkeypatch):
    """Reload the Tencent vector module with the SDK module tree stubbed out."""
    for mod_name, stub in _build_fake_tencent_modules().items():
        monkeypatch.setitem(sys.modules, mod_name, stub)
    import core.rag.datasource.vdb.tencent.tencent_vector as module

    return importlib.reload(module)
def _config(module, **overrides):
    """Build a valid TencentConfig, optionally overriding individual fields."""
    base = {
        "url": "http://vdb.local",
        "api_key": "api-key",
        "timeout": 30,
        "username": "user",
        "database": "db",
        "index_type": "HNSW",
        "metric_type": "IP",
        "shard": 1,
        "replicas": 2,
        "max_upsert_batch_size": 2,
        "enable_hybrid_search": False,
    }
    return module.TencentConfig.model_validate({**base, **overrides})
def test_config_and_init_paths(tencent_module):
    """Validate config mapping and hybrid-search detection on collection load."""
    config = _config(tencent_module)
    assert config.to_tencent_params()["url"] == "http://vdb.local"
    vector = tencent_module.TencentVector("collection_1", config)
    assert vector.get_type() == tencent_module.VectorType.TENCENT
    assert vector._client.kwargs["key"] == "api-key"
    # Existing collection exposing a sparse_vector index enables hybrid search
    # and sets the dense dimension from the vector index.
    vector._client.exists_collection.return_value = True
    vector._client.describe_collection.return_value = SimpleNamespace(
        indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)]
    )
    vector._client_config.enable_hybrid_search = True
    vector._load_collection()
    assert vector._enable_hybrid_search is True
    assert vector._dimension == 768
    # Without a sparse index, hybrid search stays disabled.
    vector._client.describe_collection.return_value = SimpleNamespace(
        indexes=[SimpleNamespace(name="vector", dimension=512)]
    )
    vector._load_collection()
    assert vector._enable_hybrid_search is False
def test_create_collection_branches(tencent_module, monkeypatch):
    """Cover redis gating, invalid index/metric types and the hybrid-search retry."""
    vector = tencent_module.TencentVector("collection_1", _config(tencent_module))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock())
    # Cache hit in redis: nothing is created.
    monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    vector._client.create_collection.assert_not_called()
    # Collection already exists remotely: creation is still skipped.
    monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.exists_collection.return_value = True
    vector._create_collection(3)
    vector._client.create_collection.assert_not_called()
    vector._client.exists_collection.return_value = False
    # Invalid config values surface as ValueError.
    vector._client_config.index_type = "UNKNOWN"
    with pytest.raises(ValueError, match="unsupported index_type"):
        vector._create_collection(3)
    vector._client_config.index_type = "HNSW"
    vector._client_config.metric_type = "UNKNOWN"
    with pytest.raises(ValueError, match="unsupported metric_type"):
        vector._create_collection(3)
    vector._client_config.metric_type = "IP"
    # First attempt fails with a json-unsupported error and a retry succeeds.
    vector._client.create_collection.side_effect = [
        tencent_module.VectorDBException("fieldType:json unsupported"),
        None,
    ]
    vector._enable_hybrid_search = True
    vector._create_collection(3)
    assert vector._client.create_collection.call_count == 2
    tencent_module.redis_client.set.assert_called_once()
    vector._client.create_collection.side_effect = None
def test_create_add_delete_and_search_behaviour(tencent_module):
    """Exercise batched upsert/delete plus vector, full-text and compat search."""
    vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True))
    vector._create_collection = MagicMock()
    docs = [
        Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}),
        Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}),
        Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}),
    ]
    embeddings = [[0.1], [0.2], [0.3]]
    vector.create(docs, embeddings)
    vector._create_collection.assert_called_once_with(1)
    # max_upsert_batch_size=2 (see _config) splits 3 documents into 2 upserts.
    vector._client.upsert.reset_mock()
    vector.add_texts(docs, embeddings)
    assert vector._client.upsert.call_count == 2
    first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"]
    # With hybrid search enabled, upserted documents carry a sparse_vector.
    assert "sparse_vector" in first_docs[0].__dict__
    vector._client.query.return_value = [{"id": "a"}]
    assert vector.text_exists("a") is True
    vector._client.query.return_value = []
    assert vector.text_exists("a") is False
    # Deletions short-circuit on empty input and are batched like upserts.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector.delete_by_ids(["a", "b", "c"])
    assert vector._client.delete.call_count == 2
    vector.delete_by_metadata_field("document_id", "doc-a")
    assert vector._client.delete.call_count >= 3
    vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]]
    vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"])
    assert len(vec_docs) == 1
    assert vec_docs[0].metadata["score"] == pytest.approx(0.9)
    # Full-text search is a no-op unless hybrid search is enabled.
    vector._enable_hybrid_search = False
    assert vector.search_by_full_text("query") == []
    vector._enable_hybrid_search = True
    vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]]
    fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"])
    assert len(fts_docs) == 1
    # _get_search_res handles old string metadata format
    compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.5)
    assert len(compat_docs) == 1
    assert compat_docs[0].metadata["score"] == pytest.approx(0.8)
    vector._has_collection = MagicMock(return_value=True)
    vector.delete()
    vector._client.drop_collection.assert_called_once()
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch):
    """The factory reuses a stored class_prefix, else generates a collection name.

    Collection names are normalized to lower case before reaching TencentVector.
    """
    factory = tencent_module.TencentVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_URL", "http://vdb.local")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_API_KEY", "api-key")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_TIMEOUT", 30)
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_USERNAME", "user")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_DATABASE", "db")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_REPLICAS", 2)
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH", True)
    with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A missing index struct is persisted onto the dataset by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,88 @@
from types import SimpleNamespace
import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
class _DummyVector(BaseVector):
    """Minimal concrete BaseVector used to exercise the base-class helpers."""

    def __init__(self, collection_name: str, existing_ids: set[str] | None = None):
        super().__init__(collection_name)
        # IDs that text_exists() should report as already stored.
        self._existing_ids = set(existing_ids or ())

    def get_type(self) -> str:
        return "dummy"

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        pass

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        pass

    def text_exists(self, id: str) -> bool:
        return id in self._existing_ids

    def delete_by_ids(self, ids: list[str]):
        pass

    def delete_by_metadata_field(self, key: str, value: str):
        pass

    def search_by_vector(self, query_vector: list[float], **kwargs):
        return []

    def search_by_full_text(self, query: str, **kwargs):
        return []

    def delete(self):
        pass
@pytest.mark.parametrize(
    ("base_method", "args"),
    [
        (BaseVector.get_type, ()),
        (BaseVector.create, ([], [])),
        (BaseVector.add_texts, ([], [])),
        (BaseVector.text_exists, ("doc-1",)),
        (BaseVector.delete_by_ids, ([],)),
        (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")),
        (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")),
        (BaseVector.search_by_vector, ([0.1],)),
        (BaseVector.search_by_full_text, ("query",)),
        (BaseVector.delete, ()),
    ],
)
def test_base_vector_default_methods_raise_not_implemented(base_method, args):
    """Every abstract BaseVector entry point raises NotImplementedError by default."""
    vector = _DummyVector("collection_1")
    # Call the unbound base implementation so _DummyVector overrides are bypassed.
    with pytest.raises(NotImplementedError):
        base_method(vector, *args)
def test_filter_duplicate_texts_removes_existing_docs():
    """Docs whose doc_id is already stored are dropped; everything else survives."""
    store = _DummyVector("collection_1", existing_ids={"dup"})
    candidates = [
        SimpleNamespace(page_content="keep-no-meta", metadata=None),
        Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}),
        Document(page_content="remove-dup", metadata={"doc_id": "dup"}),
        Document(page_content="keep-unique", metadata={"doc_id": "unique"}),
    ]
    kept = store._filter_duplicate_texts(candidates)
    assert [doc.page_content for doc in kept] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"]
def test_get_uuids_and_collection_name_property():
    """_get_uuids collects only explicit doc_id values; collection_name echoes init."""
    store = _DummyVector("collection_1")
    documents = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        SimpleNamespace(page_content="b", metadata=None),
        Document(page_content="c", metadata={"document_id": "d-1"}),
        Document(page_content="d", metadata={"doc_id": "id-2"}),
    ]
    assert store._get_uuids(documents) == ["id-1", "id-2"]
    assert store.collection_name == "collection_1"

View File

@ -0,0 +1,434 @@
import base64
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str):
    """Install a stub module exposing an empty factory class; return that class."""
    stub_cls = type(class_name, (), {})
    stub_module = types.ModuleType(module_path)
    setattr(stub_module, class_name, stub_cls)
    monkeypatch.setitem(sys.modules, module_path, stub_module)
    return stub_cls
@pytest.fixture
def vector_factory_module():
    """Provide a freshly reloaded vector_factory module for each test."""
    import importlib

    import core.rag.datasource.vdb.vector_factory as module

    return importlib.reload(module)
def test_gen_index_struct_dict(vector_factory_module):
    """gen_index_struct_dict wraps the collection name under vector_store."""
    struct = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict(
        vector_factory_module.VectorType.WEAVIATE,
        "collection_1",
    )
    expected = {
        "type": vector_factory_module.VectorType.WEAVIATE,
        "vector_store": {"class_prefix": "collection_1"},
    }
    assert struct == expected
@pytest.mark.parametrize(
    ("vector_type", "module_path", "class_name"),
    [
        ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"),
        ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"),
        (
            "ALIBABACLOUD_MYSQL",
            "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector",
            "AlibabaCloudMySQLVectorFactory",
        ),
        ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"),
        ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"),
        ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"),
        ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"),
        ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"),
        ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"),
        (
            "ELASTICSEARCH",
            "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector",
            "ElasticSearchVectorFactory",
        ),
        (
            "ELASTICSEARCH_JA",
            "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector",
            "ElasticSearchJaVectorFactory",
        ),
        ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"),
        ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"),
        ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"),
        ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"),
        (
            "OPENSEARCH",
            "core.rag.datasource.vdb.opensearch.opensearch_vector",
            "OpenSearchVectorFactory",
        ),
        ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"),
        ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"),
        ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"),
        ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"),
        ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"),
        (
            "TIDB_ON_QDRANT",
            "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector",
            "TidbOnQdrantVectorFactory",
        ),
        ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"),
        ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"),
        ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"),
        ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"),
        ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"),
        (
            "HUAWEI_CLOUD",
            "core.rag.datasource.vdb.huawei.huawei_cloud_vector",
            "HuaweiCloudVectorFactory",
        ),
        ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"),
        ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"),
        ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"),
    ],
)
def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name):
    """Every supported VectorType resolves to the factory class in its module."""
    # Each target module is replaced with a stub exposing a bare class of the
    # expected name, so resolution is tested without importing real backends.
    expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name)
    result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type))
    assert result_cls is expected_cls
def test_get_vector_factory_unsupported(vector_factory_module):
    """An unrecognized vector type must be rejected by the factory lookup."""
    unknown_type = "unknown"
    with pytest.raises(ValueError, match="not supported"):
        vector_factory_module.Vector.get_vector_factory(unknown_type)
def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
    """Vector.__init__ falls back to the default attribute list and honors overrides."""
    dataset = SimpleNamespace(id="dataset-1")
    # Patch the expensive helpers so construction never touches models or the DB.
    with (
        patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"),
        patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"),
    ):
        default_vector = vector_factory_module.Vector(dataset)
        custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
    assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
    assert custom_vector._attributes == ["doc_id"]
    assert default_vector._embeddings == "embeddings"
    assert default_vector._vector_processor == "processor"
def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch):
    """When the dataset carries an index struct, its vector type drives factory selection."""
    calls = {"vector_type": None, "init_args": None}

    class _Factory:
        def init_vector(self, dataset, attributes, embeddings):
            calls["init_args"] = (dataset, attributes, embeddings)
            return "vector-processor"

    # Record which vector type the factory lookup receives.
    monkeypatch.setattr(
        vector_factory_module.Vector,
        "get_vector_factory",
        staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory),
    )
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(
        index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1"
    )
    vector._attributes = ["doc_id"]
    vector._embeddings = "embeddings"
    result = vector._init_vector()
    assert result == "vector-processor"
    assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH
    assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings")
def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch):
    """A whitelisted tenant with CHROMA configured is rerouted to TIDB_ON_QDRANT."""

    class _Expr:
        # Stand-in for a SQLAlchemy column: equality yields a filter expression.
        def __eq__(self, _other):
            return "expr"

    calls = {"vector_type": None}

    class _Factory:
        def init_vector(self, dataset, attributes, embeddings):
            return "vector-processor"

    monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr()))
    monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    # A non-None scalar result marks the tenant as whitelisted.
    monkeypatch.setattr(
        vector_factory_module,
        "db",
        SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: object()))),
    )
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA)
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True)
    monkeypatch.setattr(
        vector_factory_module.Vector,
        "get_vector_factory",
        staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory),
    )
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1")
    vector._attributes = ["doc_id"]
    vector._embeddings = "embeddings"
    result = vector._init_vector()
    assert result == "vector-processor"
    # The configured CHROMA store is overridden by the whitelist routing.
    assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT
def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch):
    """Without an index struct and with VECTOR_STORE unset, _init_vector must fail fast."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1")
    vector._attributes = []
    vector._embeddings = "embeddings"
    # No configured store and whitelist routing disabled.
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", None)
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", False)
    with pytest.raises(ValueError, match="Vector store must be specified"):
        vector._init_vector()
def test_create_batches_texts_and_skips_empty_input(vector_factory_module):
    """create() embeds documents in batches of 1000 and no-ops on empty input."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()
    # 1001 docs force exactly two embedding batches (1000 + 1).
    docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)]
    vector._embeddings.embed_documents.side_effect = [
        [[0.1] for _ in range(1000)],
        [[0.2]],
    ]
    vector.create(texts=docs, trace_id="trace-1")
    assert vector._embeddings.embed_documents.call_count == 2
    assert vector._vector_processor.create.call_count == 2
    # Extra kwargs such as trace_id are forwarded to the processor.
    assert vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1"
    vector._embeddings.embed_documents.reset_mock()
    vector._vector_processor.create.reset_mock()
    # None input short-circuits without touching embeddings or the store.
    vector.create(texts=None)
    vector._embeddings.embed_documents.assert_not_called()
    vector._vector_processor.create.assert_not_called()
def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch):
    """create_multimodal drops documents whose upload file is absent from the database."""

    class _Field:
        # Minimal SQLAlchemy column stand-in supporting IN and equality filters.
        def in_(self, value):
            return value

        def __eq__(self, value):
            return value

    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]]
    vector._vector_processor = MagicMock()
    monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
    monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    # Only upload "f-1" exists, so document "f-2" must be filtered out.
    monkeypatch.setattr(
        vector_factory_module,
        "db",
        SimpleNamespace(
            session=SimpleNamespace(
                scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")])
            )
        ),
    )
    monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc"))
    docs = [
        Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}),
        Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}),
    ]
    vector.create_multimodal(file_documents=docs, request_id="r-1")
    # File bytes are base64-encoded before being sent to the embedding model.
    file_base64 = base64.b64encode(b"abc").decode()
    vector._embeddings.embed_multimodal_documents.assert_called_once_with(
        [{"content": file_base64, "content_type": "image", "file_id": "f-1"}]
    )
    vector._vector_processor.create.assert_called_once_with(
        texts=[docs[0]],
        embeddings=[[0.1, 0.2]],
        request_id="r-1",
    )
    vector._embeddings.embed_multimodal_documents.reset_mock()
    vector._vector_processor.create.reset_mock()
    # None input short-circuits without any embedding or store calls.
    vector.create_multimodal(file_documents=None)
    vector._embeddings.embed_multimodal_documents.assert_not_called()
    vector._vector_processor.create.assert_not_called()
def test_add_texts_with_optional_duplicate_check(vector_factory_module):
    """add_texts filters duplicates only when duplicate_check=True and forwards extra kwargs."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()
    vector._filter_duplicate_texts = MagicMock()
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    # The duplicate filter keeps only the first document.
    vector._filter_duplicate_texts.return_value = [docs[0]]
    vector._embeddings.embed_documents.return_value = [[0.1]]
    vector.add_texts(docs, duplicate_check=True, flag=True)
    vector._filter_duplicate_texts.assert_called_once_with(docs)
    vector._vector_processor.create.assert_called_once_with(
        texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True
    )
    vector._filter_duplicate_texts.reset_mock()
    vector._vector_processor.create.reset_mock()
    vector._embeddings.embed_documents.return_value = [[0.2], [0.3]]
    # With duplicate_check=False the filter must be bypassed entirely.
    vector.add_texts(docs, duplicate_check=False)
    vector._filter_duplicate_texts.assert_not_called()
    vector._vector_processor.create.assert_called_once()
def test_vector_delegation_methods(vector_factory_module):
    """The Vector facade delegates exists/delete/search calls to its processor."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._embeddings.embed_query.return_value = [0.1, 0.2]
    vector._vector_processor = MagicMock()
    vector._vector_processor.text_exists.return_value = True
    vector._vector_processor.search_by_vector.return_value = ["vector-doc"]
    vector._vector_processor.search_by_full_text.return_value = ["text-doc"]
    assert vector.text_exists("doc-1") is True
    vector.delete_by_ids(["doc-1"])
    vector.delete_by_metadata_field("doc_id", "doc-1")
    # Search calls return whatever the processor produced, untouched.
    assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"]
    assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"]
    vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"])
    vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1")
def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
    """search_by_file returns [] for unknown files and embeds known file bytes as an image query."""

    class _Field:
        # Column stand-in whose equality produces a filter expression.
        def __eq__(self, value):
            return value

    upload_query = MagicMock()
    upload_query.where.return_value = upload_query
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()
    monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
    monkeypatch.setattr(
        vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
    )
    # Missing upload: no embedding happens and an empty result is returned.
    upload_query.first.return_value = None
    assert vector.search_by_file("file-1") == []
    # Existing upload: bytes are loaded from storage and embedded.
    upload_query.first.return_value = SimpleNamespace(key="blob-key")
    monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
    vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
    vector._vector_processor.search_by_vector.return_value = ["hit"]
    result = vector.search_by_file("file-2", top_k=2)
    assert result == ["hit"]
    payload = vector._embeddings.embed_multimodal_query.call_args.args[0]
    assert payload["content_type"] == vector_factory_module.DocType.IMAGE
    assert payload["file_id"] == "file-2"
def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch):
    """delete() clears the per-collection Redis key only when a collection name is set."""
    delete_mock = MagicMock()
    redis_delete = MagicMock()
    monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete)
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1")
    vector.delete()
    delete_mock.assert_called_once()
    redis_delete.assert_called_once_with("vector_indexing_collection_1")
    # An empty collection name must leave Redis untouched.
    vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="")
    redis_delete.reset_mock()
    vector.delete()
    redis_delete.assert_not_called()
def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch):
    """_get_embeddings wraps the dataset's embedding model instance in a CacheEmbedding."""
    model_manager = MagicMock()
    model_manager.get_model_instance.return_value = "model-instance"
    monkeypatch.setattr(vector_factory_module, "ModelManager", MagicMock(return_value=model_manager))
    monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding"))
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    # Dataset supplies tenant/provider/model used to resolve the embedding model.
    vector._dataset = SimpleNamespace(
        tenant_id="tenant-1",
        embedding_model_provider="openai",
        embedding_model="text-embedding-3-small",
    )
    result = vector._get_embeddings()
    assert result == "cached-embedding"
    model_manager.get_model_instance.assert_called_once_with(
        tenant_id="tenant-1",
        provider="openai",
        model_type=vector_factory_module.ModelType.TEXT_EMBEDDING,
        model="text-embedding-3-small",
    )
def test_filter_duplicate_texts_and_getattr(vector_factory_module):
    """_filter_duplicate_texts keeps docs without usable ids and drops known ids; __getattr__ proxies."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    # Only the id "dup" is considered already indexed.
    vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup")
    docs = [
        SimpleNamespace(page_content="no-meta", metadata=None),
        Document(page_content="empty-doc-id", metadata={"doc_id": ""}),
        Document(page_content="duplicate", metadata={"doc_id": "dup"}),
        Document(page_content="unique", metadata={"doc_id": "ok"}),
    ]
    filtered = vector._filter_duplicate_texts(docs)
    assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"]

    class _Processor:
        def ping(self):
            return "pong"

    # Unknown attributes are forwarded to the underlying processor.
    vector._vector_processor = _Processor()
    assert vector.ping() == "pong"
    with pytest.raises(AttributeError):
        _ = vector.unknown_method
    # Without a processor, attribute access must fail mentioning the processor.
    vector._vector_processor = None
    with pytest.raises(AttributeError, match="vector_processor"):
        _ = vector.another_missing

View File

@ -0,0 +1,443 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
@pytest.fixture
def tidb_module():
    """Reload the TiDB vector module so each test starts from pristine module state."""
    from core.rag.datasource.vdb.tidb_vector import tidb_vector

    return importlib.reload(tidb_vector)
def _config(tidb_module):
    """Return a fully valid baseline TiDBVectorConfig for the tests to mutate."""
    settings = {
        "host": "localhost",
        "port": 4000,
        "user": "root",
        "password": "secret",
        "database": "dify",
        "program_name": "dify-app",
    }
    return tidb_module.TiDBVectorConfig(**settings)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config TIDB_VECTOR_HOST is required"),
        ("port", 0, "config TIDB_VECTOR_PORT is required"),
        ("user", "", "config TIDB_VECTOR_USER is required"),
        ("database", "", "config TIDB_VECTOR_DATABASE is required"),
        ("program_name", "", "config APPLICATION_NAME is required"),
    ],
)
def test_tidb_config_validation(tidb_module, field, value, message):
    """Each required config field is rejected with its field-specific error message."""
    # Start from a valid config and blank out one required field at a time.
    values = _config(tidb_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        tidb_module.TiDBVectorConfig.model_validate(values)
def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
    """The constructor builds the MySQL URL and maps distance names to TiDB SQL functions."""
    monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine"))
    vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2")
    assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR
    assert vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify")
    assert vector._dimension == 1536
    assert vector._get_distance_func() == "VEC_L2_DISTANCE"
    vector._distance_func = "cosine"
    assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
    # Unknown distance names fall back to cosine distance.
    vector._distance_func = "other"
    assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch):
    """_table declares the primary-key, vector and text columns on the collection table."""
    # Fake the optional tidb_vector dependency so VectorType resolves during _table().
    fake_tidb_vector = types.ModuleType("tidb_vector")
    fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy")

    class _VectorType:
        def __init__(self, dim):
            self.dim = dim

    fake_tidb_sqlalchemy.VectorType = _VectorType
    monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector)
    monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy)
    monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock()))
    # Capture Column/Table arguments instead of building real SQLAlchemy objects.
    monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(
        tidb_module,
        "Table",
        lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns),
    )
    vector = tidb_module.TiDBVector("collection_1", _config(tidb_module))
    table = vector._table(3)
    assert table.name == "collection_1"
    column_names = [column.args[0] for column in table.columns]
    assert tidb_module.Field.PRIMARY_KEY in column_names
    assert tidb_module.Field.VECTOR in column_names
    assert tidb_module.Field.TEXT_KEY in column_names
def test_create_calls_collection_and_add_texts(tidb_module):
    """create() derives the dimension from the first embedding, creates the table, then inserts."""
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    documents = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    embeddings = [[0.1, 0.2]]
    vector.create(documents, embeddings)
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(documents, embeddings)
    assert vector._dimension == 2
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
    """A Redis cache hit for the collection skips both the DDL session and the cache write."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock))
    # Cache hit: redis get returns a truthy value, so no table creation should happen.
    monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    # Patch Session via monkeypatch (not direct module assignment) so the patch
    # is reverted after the test even if an assertion fails.
    session_cls = MagicMock()
    monkeypatch.setattr(tidb_module, "Session", session_cls)
    vector._create_collection(3)
    session_cls.assert_not_called()
    tidb_module.redis_client.set.assert_not_called()
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch):
    """On a cache miss the table is created with the expected DDL and the cache key is set."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock))
    # Cache miss: redis get returns None, so the DDL must run.
    monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock())
    session = MagicMock()

    class _SessionCtx:
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector._distance_func = "l2"
    vector._create_collection(3)
    session.begin.assert_called_once()
    # The generated DDL embeds the dimension and the distance function.
    sql = str(session.execute.call_args.args[0])
    assert "VECTOR<FLOAT>(3)" in sql
    assert "VEC_L2_DISTANCE" in sql
    session.commit.assert_called_once()
    tidb_module.redis_client.set.assert_called_once()
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
    """add_texts inserts rows in batches of 500 and returns the per-document ids."""

    class _InsertStmt:
        # Minimal insert() stand-in capturing the table and row payload.
        def __init__(self, table):
            self.table = table

        def values(self, rows):
            return {"table": self.table, "rows": rows}

    monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table))
    conn = MagicMock()
    transaction = MagicMock()
    transaction.__enter__.return_value = None
    transaction.__exit__.return_value = None
    conn.begin.return_value = transaction
    connection_ctx = MagicMock()
    connection_ctx.__enter__.return_value = conn
    connection_ctx.__exit__.return_value = None
    engine = MagicMock()
    engine.connect.return_value = connection_ctx
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._engine = engine
    vector._table = MagicMock(return_value="table")
    # 501 documents force exactly two insert batches (500 + 1).
    docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)]
    embeddings = [[float(i)] for i in range(501)]
    ids = vector.add_texts(docs, embeddings)
    assert ids[0] == "id-0"
    assert len(ids) == 501
    assert conn.execute.call_count == 2
@pytest.fixture
def tidb_vector_with_session(tidb_module, monkeypatch):
    """Yield a bare TiDBVector wired to a mocked SQLAlchemy Session context manager."""
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    session = MagicMock()

    class _SessionCtx:
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    # Expose the module too so tests can bind real methods onto the instance.
    return vector, session, tidb_module
# 1. search_by_full_text returns empty
def test_search_by_full_text_returns_empty(tidb_vector_with_session):
    """The TiDB backend has no full-text search, so it always yields an empty list."""
    vector = tidb_vector_with_session[0]
    result = vector.search_by_full_text("query")
    assert result == []
# 2. text_exists returns True when ids found
def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session):
    """text_exists reports True whenever the metadata lookup yields any ids."""
    vector = tidb_vector_with_session[0]
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    exists = vector.text_exists("doc-1")
    assert exists is True
# 3. text_exists returns False when no ids
def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session):
    """text_exists reports False when the metadata lookup yields nothing."""
    vector = tidb_vector_with_session[0]
    vector.get_ids_by_metadata_field = MagicMock(return_value=None)
    exists = vector.text_exists("doc-1")
    assert exists is False
# 4. delete_by_ids delegates to _delete_by_ids when ids found
def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session):
    """delete_by_ids resolves document ids to row ids and forwards them for deletion."""
    vector, session, tidb_module = tidb_vector_with_session
    # Two matching rows come back from the metadata query.
    session.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)]
    vector._delete_by_ids = MagicMock()
    # Use real get_ids_by_metadata_field
    vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__(
        vector, tidb_module.TiDBVector
    )
    vector.delete_by_ids(["doc-a", "doc-b"])
    vector._delete_by_ids.assert_called_once_with(["id-a", "id-b"])
# 5. delete_by_ids skips when no ids found
def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session):
    """When no row ids match, the internal delete helper is never invoked."""
    vector, session, tidb_module = tidb_vector_with_session
    session.execute.return_value.fetchall.return_value = []
    vector._delete_by_ids = MagicMock()
    # Use real get_ids_by_metadata_field
    vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__(
        vector, tidb_module.TiDBVector
    )
    vector.delete_by_ids(["doc-c"])
    vector._delete_by_ids.assert_not_called()
# 6. get_ids_by_metadata_field returns ids and returns None
def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session):
    """get_ids_by_metadata_field unpacks row tuples into ids, or returns None when empty."""
    vector, session, tidb_module = tidb_vector_with_session
    # Returns ids
    session.execute.return_value.fetchall.return_value = [("id-1",)]
    assert vector.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"]
    # Returns None
    session.execute.return_value.fetchall.return_value = []
    assert vector.get_ids_by_metadata_field("doc_id", "doc-1") is None
# 1. _delete_by_ids raises on None
def test__delete_by_ids_raises_on_none(tidb_module):
    """Passing None as the id list must be rejected with a ValueError."""
    instance = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    with pytest.raises(ValueError, match="No ids provided"):
        instance._delete_by_ids(None)
# 2. _delete_by_ids returns True and calls execute
def test__delete_by_ids_returns_true_and_calls_execute(tidb_module):
    """A successful delete issues exactly one DELETE statement and reports True."""

    class _IDColumn:
        # Column stand-in whose IN-clause just echoes the ids.
        def in_(self, ids):
            return ids

    class _Delete:
        def where(self, condition):
            return condition

    table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = None
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn_ctx = MagicMock()
    conn_ctx.__enter__.return_value = conn
    conn_ctx.__exit__.return_value = None
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._dimension = 2
    vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx))
    vector._table = MagicMock(return_value=table)
    assert vector._delete_by_ids(["id-1"]) is True
    conn.execute.assert_called_once()
# 3. _delete_by_ids returns False on RuntimeError
def test__delete_by_ids_returns_false_on_runtime_error(tidb_module):
    """A RuntimeError during the delete is swallowed and reported as False."""

    class _IDColumn:
        def in_(self, ids):
            return ids

    class _Delete:
        def where(self, condition):
            return condition

    table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = None
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn_ctx = MagicMock()
    conn_ctx.__enter__.return_value = conn
    conn_ctx.__exit__.return_value = None
    # Simulate the database failing mid-delete.
    conn.execute.side_effect = RuntimeError("delete failed")
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._dimension = 2
    vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx))
    vector._table = MagicMock(return_value=table)
    assert vector._delete_by_ids(["id-2"]) is False
# 4. delete_by_metadata_field calls _delete_by_ids when ids found
def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module):
    """Ids found for a metadata field are forwarded verbatim to the internal delete helper."""
    instance = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    instance.get_ids_by_metadata_field = MagicMock(return_value=["id-3"])
    instance._delete_by_ids = MagicMock()
    instance.delete_by_metadata_field("doc_id", "doc-3")
    instance._delete_by_ids.assert_called_once_with(["id-3"])
# 5. delete_by_metadata_field does nothing when no ids
def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module):
    """With no matching ids, delete_by_metadata_field never touches the delete helper."""
    instance = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    instance.get_ids_by_metadata_field = MagicMock(return_value=[])
    instance._delete_by_ids = MagicMock()
    instance.delete_by_metadata_field("doc_id", "doc-4")
    instance._delete_by_ids.assert_not_called()
# Test search_by_vector filters and scores
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
    """search_by_vector converts distance to score and applies the document-id filter."""
    session = MagicMock()
    # Result rows are (metadata JSON, text, distance); score = 1 - distance.
    session.execute.return_value = [
        ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2),
        ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4),
    ]
    session.commit = MagicMock()

    class _SessionCtx:
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector._distance_func = "cosine"
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    assert len(docs) == 2
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    assert docs[1].metadata["score"] == pytest.approx(0.6)
    sql = str(session.execute.call_args.args[0])
    params = session.execute.call_args.kwargs["params"]
    assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql
    # The score threshold becomes a max-distance parameter in the query.
    assert params["distance"] == pytest.approx(0.5)
    assert params["top_k"] == 2
    session.commit.assert_not_called()
# Test delete drops table
def test_delete_drops_table(tidb_module, monkeypatch):
    """delete() drops the backing table and commits the session."""
    session = MagicMock()
    session.execute.return_value = None
    session.commit = MagicMock()

    class _SessionCtx:
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector.delete()
    drop_sql = str(session.execute.call_args.args[0])
    assert "DROP TABLE IF EXISTS collection_1" in drop_sql
    session.commit.assert_called_once()
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
    """The factory reuses the dataset's class prefix or generates a new collection name."""
    factory = tidb_module.TiDBVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000)
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify")
    monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app")
    with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are lower-cased regardless of origin.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one persisted by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,186 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_upstash_module():
upstash_module = types.ModuleType("upstash_vector")
class Vector:
def __init__(self, id, vector, metadata, data):
self.id = id
self.vector = vector
self.metadata = metadata
self.data = data
class Index:
def __init__(self, url, token):
self.url = url
self.token = token
self.info = MagicMock(return_value=SimpleNamespace(dimension=8))
self.upsert = MagicMock()
self.query = MagicMock(return_value=[])
self.delete = MagicMock()
self.reset = MagicMock()
upstash_module.Vector = Vector
upstash_module.Index = Index
return upstash_module
@pytest.fixture
def upstash_module(monkeypatch):
    """Import the Upstash vector module against a faked upstash_vector dependency."""
    # Remove patched modules if present
    for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]:
        if modname in sys.modules:
            monkeypatch.delitem(sys.modules, modname, raising=False)
    # Install the fake SDK before the module under test is (re)imported.
    monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module())
    module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector")
    return module
def _config(module):
    """Return a valid baseline UpstashVectorConfig for the validation tests."""
    return module.UpstashVectorConfig(
        url="https://upstash.example",
        token="token-123",
    )
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("url", "", "Upstash URL is required"),
        ("token", "", "Upstash Token is required"),
    ],
)
def test_upstash_config_validation(upstash_module, field, value, message):
    """Blank url/token values must fail validation with their dedicated messages."""
    # Start from a valid config and blank out one required field at a time.
    values = _config(upstash_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        upstash_module.UpstashVectorConfig.model_validate(values)
def test_init_get_type_and_dimension(upstash_module, monkeypatch):
    """Dimension comes from index info with a 1536 fallback; upserts get generated ids."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    assert vector.get_type() == upstash_module.VectorType.UPSTASH
    assert vector._table_name == "collection_1"
    assert vector._get_index_dimension() == 8
    # Missing dimension info falls back to the 1536 default.
    vector.index.info.return_value = SimpleNamespace(dimension=None)
    assert vector._get_index_dimension() == 1536
    vector.index.info.return_value = None
    assert vector._get_index_dimension() == 1536
    # Ids for upserted vectors are freshly generated UUIDs.
    monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid")
    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.add_texts(docs, [[0.1, 0.2]])
    vector.index.upsert.assert_called_once()
    upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"]
    assert upsert_vectors[0].id == "generated-uuid"
def test_create_text_exists_and_delete_by_ids(upstash_module):
    """create delegates to add_texts; text_exists/delete_by_ids rely on metadata lookups."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1]])
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    assert vector.text_exists("doc-1") is True
    vector.get_ids_by_metadata_field.return_value = []
    assert vector.text_exists("doc-1") is False
    # delete_by_ids aggregates the ids found per document into one delete call.
    vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]])
    vector._delete_by_ids = MagicMock()
    vector.delete_by_ids(["doc-1", "doc-2", "doc-3"])
    vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"])
def test_delete_helpers_and_search(upstash_module):
    """_delete_by_ids skips empty input; metadata lookups drive id queries and deletes."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    # Empty id list: no delete issued.
    vector._delete_by_ids([])
    vector.index.delete.assert_not_called()
    vector._delete_by_ids(["a", "b"])
    vector.index.delete.assert_called_once_with(ids=["a", "b"])
    vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")]
    ids = vector.get_ids_by_metadata_field("doc_id", "doc-1")
    assert ids == ["x-1", "x-2"]
    # The lookup queries with a wide top_k and a metadata equality filter.
    query_kwargs = vector.index.query.call_args.kwargs
    assert query_kwargs["top_k"] == 1000
    assert query_kwargs["filter"] == "doc_id = 'doc-1'"
    vector._delete_by_ids = MagicMock()
    vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"])
    vector.delete_by_metadata_field("doc_id", "doc-1")
    vector._delete_by_ids.assert_called_once_with(["x-1"])
    vector._delete_by_ids.reset_mock()
    # No matching ids: nothing to delete.
    vector.get_ids_by_metadata_field.return_value = []
    vector.delete_by_metadata_field("doc_id", "doc-2")
    vector._delete_by_ids.assert_not_called()
def test_search_by_vector_filter_threshold_and_delete(upstash_module):
    """search_by_vector drops low scores, missing metadata, and missing data; delete resets."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    # Four candidates: one valid, one below threshold, one without metadata,
    # one without text payload — only the first should survive.
    vector.index.query.return_value = [
        SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9),
        SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3),
        SimpleNamespace(metadata=None, data="text-3", score=0.99),
        SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99),
    ]
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=3,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    assert len(docs) == 1
    assert docs[0].page_content == "text-1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # The document filter is rendered into Upstash filter syntax.
    search_kwargs = vector.index.query.call_args.kwargs
    assert search_kwargs["top_k"] == 3
    assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')"
    # Full-text search is unsupported, and delete() resets the whole index.
    assert vector.search_by_full_text("query") == []
    vector.delete()
    vector.index.reset.assert_called_once()
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch):
    """Factory reuses an existing class_prefix (lowercased) or generates one per dataset."""
    factory = upstash_module.UpstashVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example")
    monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123")
    with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are lowercased regardless of origin.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one persisted on it.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,310 @@
import importlib
import json
import sys
import types
from collections import UserDict
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_vikingdb_modules():
volcengine = types.ModuleType("volcengine")
volcengine.__path__ = []
viking_db = types.ModuleType("volcengine.viking_db")
class Data(UserDict):
def __init__(self, payload):
super().__init__(payload)
self.fields = payload
class DistanceType:
L2 = "L2"
class IndexType:
HNSW = "HNSW"
class QuantType:
Float = "Float"
class FieldType:
String = "string"
Text = "text"
Vector = "vector"
class Field:
def __init__(self, **kwargs):
self.kwargs = kwargs
class VectorIndexParams:
def __init__(self, **kwargs):
self.kwargs = kwargs
class _Collection:
def __init__(self):
self.upsert_data = MagicMock()
self.fetch_data = MagicMock(return_value=None)
self.delete_data = MagicMock()
class _Index:
def __init__(self):
self.search = MagicMock(return_value=[])
self.search_by_vector = MagicMock(return_value=[])
class VikingDBService:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_collection = MagicMock()
self.create_index = MagicMock()
self.drop_index = MagicMock()
self.drop_collection = MagicMock()
self._collection = _Collection()
self._index = _Index()
self.get_collection = MagicMock(return_value=self._collection)
self.get_index = MagicMock(return_value=self._index)
viking_db.Data = Data
viking_db.DistanceType = DistanceType
viking_db.Field = Field
viking_db.FieldType = FieldType
viking_db.IndexType = IndexType
viking_db.QuantType = QuantType
viking_db.VectorIndexParams = VectorIndexParams
viking_db.VikingDBService = VikingDBService
return {"volcengine": volcengine, "volcengine.viking_db": viking_db}
@pytest.fixture
def vikingdb_module(monkeypatch):
    """Install the fake volcengine SDK and return a freshly reloaded vector module."""
    for name, module in _build_fake_vikingdb_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module
    # Reload so the module re-binds to the fake SDK injected above.
    return importlib.reload(module)
def _config(module):
return module.VikingDBConfig(
access_key="ak",
secret_key="sk",
host="host",
region="region",
scheme="https",
connection_timeout=10,
socket_timeout=20,
)
def test_init_get_type_and_has_checks(vikingdb_module):
    """Verify vector type, derived index name, and that _has_* swallow client errors."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB
    assert vector._index_name == "collection_1_idx"
    assert vector._has_collection() is True
    assert vector._has_index() is True
    # Client failures are treated as "does not exist" rather than propagated.
    vector._client.get_collection.side_effect = RuntimeError("missing")
    assert vector._has_collection() is False
    vector._client.get_collection.side_effect = None
    vector._client.get_index.side_effect = RuntimeError("missing")
    assert vector._has_index() is False
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch):
    """_create_collection skips work when the redis cache key exists, else creates both."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock())
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    # Cache hit: nothing is created.
    monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    vector._client.create_collection.assert_not_called()
    vector._client.create_index.assert_not_called()
    # Cache miss with neither collection nor index present: both are created
    # and the cache key is written.
    monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None))
    vector._has_collection = MagicMock(return_value=False)
    vector._has_index = MagicMock(return_value=False)
    vector._create_collection(4)
    vector._client.create_collection.assert_called_once()
    vector._client.create_index.assert_called_once()
    vikingdb_module.redis_client.set.assert_called_once()
def test_create_and_add_texts(vikingdb_module):
    """create() passes the embedding dimension through; add_texts upserts keyed rows."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1, 0.2]])
    # Dimension (2) comes from the first embedding's length.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
    # Fresh instance with a real add_texts to inspect the upserted payload.
    vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module))
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}),
        Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}),
    ]
    vector.add_texts(docs, [[0.1], [0.2]])
    vector._client.get_collection.assert_called()
    upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0]
    assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a"
    assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2"
def test_text_exists_and_delete_operations(vikingdb_module):
    """text_exists interprets fetch_data results; delete paths forward to the client."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    # Present row -> True.
    vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"})
    assert vector.text_exists("id-1") is True
    # Explicit "does not exist" message -> False.
    vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(
        fields={"message": "data does not exist"}
    )
    assert vector.text_exists("id-1") is False
    # No result at all -> False.
    vector._client.get_collection.return_value.fetch_data.return_value = None
    assert vector.text_exists("id-1") is False
    vector.delete_by_ids(["id-1"])
    vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"])
    # delete_by_metadata_field resolves ids first, then delegates to delete_by_ids.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-2"])
def test_get_ids_and_search_helpers(vikingdb_module):
    """Cover metadata-id extraction, score-sorted result building, and doc-id filtering."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vector._client.get_index.return_value.search.return_value = []
    assert vector.get_ids_by_metadata_field("doc_id", "x") == []
    # Only rows whose metadata JSON matches the key/value are returned;
    # rows without the metadata field are skipped.
    vector._client.get_index.return_value.search.return_value = [
        SimpleNamespace(id="a", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}),
        SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}),
        SimpleNamespace(id="c", fields={}),
    ]
    assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"]
    empty_docs = vector._get_search_res([], score_threshold=0.1)
    assert empty_docs == []
    results = [
        SimpleNamespace(
            id="a",
            score=0.3,
            fields={
                vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a",
                vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}),
            },
        ),
        SimpleNamespace(
            id="b",
            score=0.9,
            fields={
                vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b",
                vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}),
            },
        ),
    ]
    # Results above the threshold come back sorted by score, highest first.
    docs = vector._get_search_res(results, score_threshold=0.2)
    assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"]
    # document_ids_filter drops results whose document_id is not in the list.
    vector._client.get_index.return_value.search_by_vector.return_value = results
    filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"])
    assert len(filtered_docs) == 1
    assert filtered_docs[0].page_content == "doc-b"
    # Full-text search is unsupported for VikingDB.
    assert vector.search_by_full_text("query") == []
def test_delete_drops_index_and_collection_when_present(vikingdb_module):
    """delete() drops the index and collection only when both existence checks pass."""
    vec = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vec._has_index = MagicMock(return_value=True)
    vec._has_collection = MagicMock(return_value=True)

    # Both present: index is dropped (with its derived name), then the collection.
    vec.delete()
    vec._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx")
    vec._client.drop_collection.assert_called_once_with("collection_1")

    for drop_mock in (vec._client.drop_index, vec._client.drop_collection):
        drop_mock.reset_mock()
    vec._has_index.return_value = False
    vec._has_collection.return_value = False

    # Neither present: delete() is a no-op against the client.
    vec.delete()
    vec._client.drop_index.assert_not_called()
    vec._client.drop_collection.assert_not_called()
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch):
    """Factory reuses an existing class_prefix (lowercased) or generates a new name."""
    factory = vikingdb_module.VikingDBVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls:
        # Full config required; each value is monkeypatched in.
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10)
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20)
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are lowercased regardless of origin.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one persisted on it.
    assert dataset_without_index.index_struct is not None
@pytest.mark.parametrize(
    ("field", "message"),
    [
        ("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"),
        ("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"),
        ("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"),
        ("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"),
        ("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"),
    ],
)
def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message):
    """Each required VikingDB setting, when None, raises a ValueError naming it."""
    factory = vikingdb_module.VikingDBVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None
    )
    # Start from a fully valid config, then null out the field under test.
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https")
    monkeypatch.setattr(vikingdb_module.dify_config, field, None)
    with pytest.raises(ValueError, match=message):
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())

View File

@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in:
- Full-text search result metadata (search_by_full_text)
"""
import datetime
import json
import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from core.rag.models.document import Document
@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase):
    def tearDown(self):
        # Drop the module-level cached client so tests stay isolated from each other.
        weaviate_vector_module._weaviate_client = None
    def test_config_requires_endpoint(self):
        """WeaviateConfig validation rejects an empty endpoint."""
        with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"):
            WeaviateConfig(endpoint="")
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def _create_weaviate_vector(self, mock_weaviate_module):
"""Helper to create a WeaviateVector instance with mocked client."""
@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase):
)
return wv, mock_client
    def test_shutdown_client_logs_debug_when_close_fails(self):
        """Shutdown clears the cached client and logs (not raises) when close() fails."""
        mock_client = MagicMock()
        mock_client.close.side_effect = RuntimeError("close failed")
        weaviate_vector_module._weaviate_client = mock_client
        with patch.object(weaviate_vector_module.logger, "debug") as mock_debug:
            weaviate_vector_module._shutdown_weaviate_client()
        # Cache cleared, close attempted once, failure reported at debug level.
        assert weaviate_vector_module._weaviate_client is None
        mock_client.close.assert_called_once()
        mock_debug.assert_called_once()
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
    def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect):
        """A ready cached client is returned directly, with no new connection."""
        cached_client = MagicMock()
        cached_client.is_ready.return_value = True
        weaviate_vector_module._weaviate_client = cached_client
        wv = WeaviateVector.__new__(WeaviateVector)
        client = wv._init_client(self.config)
        assert client is cached_client
        mock_connect.assert_not_called()
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
    def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect):
        """Double-checked readiness: not ready before the lock, ready inside it."""
        cached_client = MagicMock()
        # First is_ready() (pre-lock) fails, second (post-lock recheck) succeeds.
        cached_client.is_ready.side_effect = [False, True]
        weaviate_vector_module._weaviate_client = cached_client
        wv = WeaviateVector.__new__(WeaviateVector)
        client = wv._init_client(self.config)
        assert client is cached_client
        mock_connect.assert_not_called()
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token")
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
    def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key):
        """HTTP/gRPC endpoints are split into host/port/secure args for connect_to_custom."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_connect.return_value = mock_client
        wv = WeaviateVector.__new__(WeaviateVector)
        # gRPC endpoint has no scheme: host:port parsed directly, insecure by default.
        config = WeaviateConfig(
            endpoint="https://weaviate.example.com",
            grpc_endpoint="grpc.example.com:6000",
            api_key="test-key",
            batch_size=50,
        )
        client = wv._init_client(config)
        assert client is mock_client
        # https endpoint implies port 443 and http_secure=True.
        assert mock_connect.call_args.kwargs == {
            "http_host": "weaviate.example.com",
            "http_port": 443,
            "http_secure": True,
            "grpc_host": "grpc.example.com",
            "grpc_port": 6000,
            "grpc_secure": False,
            "auth_credentials": "auth-token",
            "skip_init_checks": True,
        }
        mock_api_key.assert_called_once_with("test-key")
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
    def test_init_client_raises_when_database_not_ready(self, mock_connect):
        """A freshly connected client that reports not-ready raises ConnectionError."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = False
        mock_connect.return_value = mock_client
        wv = WeaviateVector.__new__(WeaviateVector)
        with pytest.raises(ConnectionError, match="Vector database is not ready"):
            wv._init_client(self.config)
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_init(self, mock_weaviate_module):
"""Test WeaviateVector initialization stores attributes including doc_type."""
@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase):
assert wv._collection_name == self.collection_name
assert "doc_type" in wv._attributes
    def test_get_type_and_to_index_struct(self):
        """get_type reports WEAVIATE; to_index_struct embeds the collection name."""
        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE
        assert wv.to_index_struct() == {
            "type": weaviate_vector_module.VectorType.WEAVIATE,
            "vector_store": {"class_prefix": self.collection_name},
        }
    def test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self):
        """An existing class_prefix gets the "_Node" suffix appended."""
        dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1")
        wv = WeaviateVector.__new__(WeaviateVector)
        assert wv.get_collection_name(dataset) == "ExistingCollection_Node"
    def test_get_collection_name_generates_name_from_dataset_id(self):
        """Without an index struct, the name comes from gen_collection_name_by_id."""
        dataset = SimpleNamespace(index_struct_dict=None, id="ds-2")
        wv = WeaviateVector.__new__(WeaviateVector)
        with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"):
            assert wv.get_collection_name(dataset) == "Generated_Node"
def test_create_calls_collection_setup_then_add_texts(self):
doc = Document(page_content="hello", metadata={})
wv = WeaviateVector.__new__(WeaviateVector)
wv._create_collection = MagicMock()
wv.add_texts = MagicMock()
wv.create([doc], [[0.1, 0.2]])
wv._create_collection.assert_called_once()
wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]])
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config")
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase):
f"doc_type should be in collection schema properties, got: {property_names}"
)
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
    def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis):
        """A redis cache hit short-circuits collection creation entirely."""
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = 1
        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        wv._ensure_properties = MagicMock()
        wv._create_collection()
        # No existence check, no schema work, no cache write on the fast path.
        wv._client.collections.exists.assert_not_called()
        wv._ensure_properties.assert_not_called()
        mock_redis.set.assert_not_called()
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
    def test_create_collection_logs_and_reraises_errors(self, mock_redis):
        """Failures during creation are logged via logger.exception and re-raised."""
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock(return_value=False)
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        wv._client.collections.exists.side_effect = RuntimeError("create failed")
        with patch.object(weaviate_vector_module.logger, "exception") as mock_exception:
            with pytest.raises(RuntimeError, match="create failed"):
                wv._create_collection()
        mock_exception.assert_called_once()
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module):
"""Test that _ensure_properties adds doc_type when it's missing from existing schema."""
@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase):
added_names = [call.args[0].name for call in add_calls]
assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}"
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module):
        """With only "text" present, every other core property is added in order."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        mock_cfg = MagicMock()
        # Existing schema has only the "text" property.
        mock_cfg.properties = [SimpleNamespace(name="text")]
        mock_col.config.get.return_value = mock_cfg
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        wv._ensure_properties()
        add_calls = mock_col.config.add_property.call_args_list
        added_names = [call.args[0].name for call in add_calls]
        assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"]
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
"""Test that _ensure_properties does not add doc_type when it already exists."""
@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase):
# No properties should be added
mock_col.config.add_property.assert_not_called()
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module):
        """Per-property failures warn (once per missing property) instead of raising."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        mock_cfg = MagicMock()
        mock_cfg.properties = []
        mock_col.config.get.return_value = mock_cfg
        mock_col.config.add_property.side_effect = RuntimeError("cannot add")
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        with patch.object(weaviate_vector_module.logger, "warning") as mock_warning:
            wv._ensure_properties()
        # One warning per missing core property (document_id, doc_id, doc_type, chunk_index).
        assert mock_warning.call_count == 4
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):
"""Test that search_by_vector returns doc_type in document metadata.
@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase):
assert len(docs) == 1
assert docs[0].metadata.get("doc_type") == "image"
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module):
        """Missing result metadata falls back to score 0.0; doc-id filter is applied."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        mock_obj = MagicMock()
        mock_obj.properties = {
            "text": "fallback distance result",
            "document_id": "doc-1",
            "doc_id": "segment-1",
        }
        # No per-object metadata -> no distance available -> default score.
        mock_obj.metadata = None
        mock_result = MagicMock()
        mock_result.objects = [mock_obj]
        mock_col.query.near_vector.return_value = mock_result
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        # Negative threshold keeps even the defaulted 0.0 score.
        docs = wv.search_by_vector(
            query_vector=[0.2] * 3,
            document_ids_filter=["doc-1"],
            top_k=2,
            score_threshold=-1,
        )
        assert len(docs) == 1
        assert docs[0].metadata["score"] == 0.0
        assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module):
        """search_by_vector short-circuits to [] when the collection does not exist."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = False
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        assert wv.search_by_vector(query_vector=[0.1] * 3) == []
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
"""Test that search_by_full_text also returns doc_type in document metadata."""
@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase):
assert len(docs) == 1
assert docs[0].metadata.get("doc_type") == "image"
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module):
        """BM25 search applies the document-id filter and carries the object vector."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        mock_obj = MagicMock()
        mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"}
        mock_obj.vector = [0.3, 0.4]
        mock_result = MagicMock()
        mock_result.objects = [mock_obj]
        mock_col.query.bm25.return_value = mock_result
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"])
        assert len(docs) == 1
        # The stored embedding is passed through on the returned document.
        assert docs[0].vector == [0.3, 0.4]
        assert mock_col.query.bm25.call_args.kwargs["filters"] is not None
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module):
        """search_by_full_text short-circuits to [] when the collection does not exist."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = False
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        assert wv.search_by_full_text(query="missing") == []
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
"""Test that add_texts includes doc_type from document metadata in stored properties."""
@ -310,6 +611,135 @@ class TestWeaviateVector(unittest.TestCase):
stored_props = call_kwargs.kwargs.get("properties")
assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}"
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module):
        """Invalid provided ids fall back to uuid4; datetime metadata is stored as ISO strings."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        mock_batch = MagicMock()
        mock_batch.__enter__ = MagicMock(return_value=mock_batch)
        mock_batch.__exit__ = MagicMock(return_value=False)
        mock_col.batch.dynamic.return_value = mock_batch
        created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC)
        doc = Document(page_content="text", metadata={"created_at": created_at})
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        # _get_uuids yields a non-UUID string, forcing the uuid4 fallback path.
        with (
            patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]),
            patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"),
        ):
            ids = wv.add_texts(documents=[doc], embeddings=[[]])
        assert ids == ["fallback-uuid"]
        call_kwargs = mock_batch.add_object.call_args
        assert call_kwargs.kwargs["uuid"] == "fallback-uuid"
        # Empty embedding is passed as None to weaviate.
        assert call_kwargs.kwargs["vector"] is None
        assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat()
def test_is_uuid_handles_invalid_values(self):
wv = WeaviateVector.__new__(WeaviateVector)
assert wv._is_uuid("123e4567-e89b-12d3-a456-426614174000") is True
assert wv._is_uuid("not-a-uuid") is False
    def test_delete_by_metadata_field_returns_when_collection_is_missing(self):
        """No collection -> delete_by_metadata_field returns without using it."""
        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        wv._client.collections.exists.return_value = False
        wv.delete_by_metadata_field("doc_id", "segment-1")
        wv._client.collections.use.assert_not_called()
    def test_delete_by_metadata_field_deletes_matching_objects(self):
        """With the collection present, a filtered delete_many is issued."""
        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        wv._client.collections.exists.return_value = True
        mock_col = MagicMock()
        wv._client.collections.use.return_value = mock_col
        wv.delete_by_metadata_field("doc_id", "segment-1")
        mock_col.data.delete_many.assert_called_once()
    def test_delete_removes_collection_when_present(self):
        """delete() removes the named collection when it exists."""
        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        wv._client.collections.exists.return_value = True
        wv.delete()
        wv._client.collections.delete.assert_called_once_with(self.collection_name)
    def test_text_exists_handles_missing_and_present_documents(self):
        """text_exists is False without the collection, True when a fetch returns objects."""
        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        # First call: collection missing; second call: present.
        wv._client.collections.exists.side_effect = [False, True]
        mock_col = MagicMock()
        wv._client.collections.use.return_value = mock_col
        mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()])
        assert wv.text_exists("segment-1") is False
        assert wv.text_exists("segment-1") is True
    def test_delete_by_ids_handles_missing_collections_and_404s(self):
        """delete_by_ids is a no-op without the collection and tolerates 404 per id."""

        class FakeUnexpectedStatusCodeError(Exception):
            # Minimal stand-in carrying the status_code attribute the code inspects.
            def __init__(self, status_code):
                super().__init__(f"status={status_code}")
                self.status_code = status_code

        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        # First call: collection missing (no deletes); second call: present.
        wv._client.collections.exists.side_effect = [False, True]
        mock_col = MagicMock()
        wv._client.collections.use.return_value = mock_col
        # First id raises 404 (swallowed), second succeeds.
        mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None]
        with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
            wv.delete_by_ids(["ignored"])
            wv.delete_by_ids(["missing-id", "ok-id"])
        assert mock_col.data.delete_by_id.call_count == 2
    def test_delete_by_ids_reraises_non_404_errors(self):
        """Status codes other than 404 propagate out of delete_by_ids."""

        class FakeUnexpectedStatusCodeError(Exception):
            # Minimal stand-in carrying the status_code attribute the code inspects.
            def __init__(self, status_code):
                super().__init__(f"status={status_code}")
                self.status_code = status_code

        wv = WeaviateVector.__new__(WeaviateVector)
        wv._collection_name = self.collection_name
        wv._client = MagicMock()
        wv._client.collections.exists.return_value = True
        mock_col = MagicMock()
        wv._client.collections.use.return_value = mock_col
        mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500)
        with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
            with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"):
                wv.delete_by_ids(["bad-id"])
    def test_json_serializable_converts_datetime(self):
        """_json_serializable ISO-formats datetimes and passes other values through."""
        wv = WeaviateVector.__new__(WeaviateVector)
        created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC)
        assert wv._json_serializable(created_at) == created_at.isoformat()
        assert wv._json_serializable("plain") == "plain"
class TestVectorDefaultAttributes(unittest.TestCase):
"""Tests for Vector class default attributes list."""
@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase):
assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}"
class TestWeaviateVectorFactory(unittest.TestCase):
    """Tests for WeaviateVectorFactory.init_vector collection-name resolution and config wiring."""

    def test_init_vector_uses_existing_dataset_index_struct(self):
        """An existing index_struct_dict supplies the collection name and is left untouched."""
        dataset = SimpleNamespace(
            id="dataset-1",
            index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}},
            index_struct=None,
        )
        wanted_attributes = ["doc_id"]
        with (
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88),
            patch(
                "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector"
            ) as vector_cls,
        ):
            outcome = weaviate_vector_module.WeaviateVectorFactory().init_vector(
                dataset, wanted_attributes, MagicMock()
            )

        assert outcome == "vector"
        kwargs = vector_cls.call_args.kwargs
        assert kwargs["collection_name"] == "ExistingCollection_Node"
        assert kwargs["attributes"] == wanted_attributes
        cfg = kwargs["config"]
        assert cfg.endpoint == "http://localhost:8080"
        assert cfg.grpc_endpoint == "localhost:50051"
        assert cfg.api_key == "api-key"
        assert cfg.batch_size == 88
        # A pre-existing index struct must not be rewritten by the factory.
        assert dataset.index_struct is None

    def test_init_vector_generates_collection_and_updates_index_struct(self):
        """Without an index struct, a collection name is generated and persisted on the dataset."""
        dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
        wanted_attributes = ["doc_id", "doc_type"]
        with (
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100),
            patch.object(
                weaviate_vector_module.Dataset,
                "gen_collection_name_by_id",
                return_value="GeneratedCollection_Node",
            ),
            patch(
                "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector"
            ) as vector_cls,
        ):
            outcome = weaviate_vector_module.WeaviateVectorFactory().init_vector(
                dataset, wanted_attributes, MagicMock()
            )

        assert outcome == "vector"
        assert vector_cls.call_args.kwargs["collection_name"] == "GeneratedCollection_Node"
        # The generated name is serialized back onto the dataset's index struct.
        expected_struct = {
            "type": weaviate_vector_module.VectorType.WEAVIATE,
            "vector_store": {"class_prefix": "GeneratedCollection_Node"},
        }
        assert json.loads(dataset.index_struct) == expected_struct
# Allow running this test module directly (e.g. `python this_file.py`) outside the pytest runner.
if __name__ == "__main__":
    unittest.main()

View File

@ -1164,7 +1164,7 @@ class TestConversationStatusCount:
conversation.id = str(uuid4())
# Mock the database query to return no messages
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
@ -1189,7 +1189,7 @@ class TestConversationStatusCount:
conversation.id = conversation_id
# Mock the database query to return no messages with workflow_run_id
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
@ -1274,7 +1274,7 @@ class TestConversationStatusCount:
return mock_result
# Act & Assert
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count
# Verify only 2 database queries were made (not N+1)
@ -1337,7 +1337,7 @@ class TestConversationStatusCount:
return mock_result
# Act
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count
# Assert - query should include app_id filter
@ -1382,7 +1382,7 @@ class TestConversationStatusCount:
),
]
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
with patch("models.model.db.session.scalars") as mock_scalars:
# Mock the messages query
def mock_scalars_side_effect(query):
mock_result = MagicMock()
@ -1438,7 +1438,7 @@ class TestConversationStatusCount:
),
]
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
with patch("models.model.db.session.scalars") as mock_scalars:
def mock_scalars_side_effect(query):
mock_result = MagicMock()

View File

@ -13,6 +13,10 @@ import pytest
from services.plugin.oauth_service import OAuthProxyService
def _oauth_proxy_setex_calls(redis_client) -> list:
return [call for call in redis_client.setex.call_args_list if call.args[0].startswith("oauth_proxy_context:")]
class TestCreateProxyContext:
def test_stores_context_in_redis_with_ttl(self):
context_id = OAuthProxyService.create_proxy_context(
@ -22,8 +26,9 @@ class TestCreateProxyContext:
assert context_id # non-empty UUID string
from extensions.ext_redis import redis_client
redis_client.setex.assert_called_once()
call_args = redis_client.setex.call_args
oauth_calls = _oauth_proxy_setex_calls(redis_client)
assert len(oauth_calls) == 1
call_args = oauth_calls[0]
key = call_args[0][0]
ttl = call_args[0][1]
stored_data = json.loads(call_args[0][2])

View File

@ -211,6 +211,7 @@ def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch
def test_import_app_pending_stores_import_info_in_redis():
service = AppDslService(MagicMock())
app_dsl_service.redis_client.setex.reset_mock()
result = service.import_app(
account=_account_mock(),
import_mode=ImportMode.YAML_CONTENT,
@ -375,10 +376,13 @@ def test_confirm_import_success_deletes_redis_key(monkeypatch):
created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1")
monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app)
app_dsl_service.redis_client.delete.reset_mock()
result = service.confirm_import(import_id="import-1", account=_account_mock())
assert result.status == ImportStatus.COMPLETED
assert result.app_id == "confirmed-app"
app_dsl_service.redis_client.delete.assert_called_once()
app_dsl_service.redis_client.delete.assert_called_once_with(
f"{app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX}import-1"
)
def test_confirm_import_exception_returns_failed(monkeypatch):

View File

@ -405,7 +405,7 @@ class TestAudioServiceTTS:
voice="en-US-Neural",
)
@patch("services.audio_service.db.session", autospec=True)
@patch("services.audio_service.db.session")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
"""Test successful TTS with message ID."""
@ -549,7 +549,7 @@ class TestAudioServiceTTS:
with pytest.raises(ValueError, match="Text is required"):
AudioService.transcript_tts(app_model=app, text=None)
@patch("services.audio_service.db.session", autospec=True)
@patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
"""Test that TTS returns None for invalid message ID format."""
# Arrange
@ -564,7 +564,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.db.session", autospec=True)
@patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
"""Test that TTS returns None when message doesn't exist."""
# Arrange
@ -585,7 +585,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.db.session", autospec=True)
@patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
"""Test that TTS returns None when message answer is empty."""
# Arrange

View File

@ -313,7 +313,8 @@ class TestEmailDeliveryTestHandler:
recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")],
)
subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com")
with patch.object(dify_config, "APP_WEB_URL", "http://example.com"):
subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com")
assert subs["node_title"] == "title"
assert subs["form_content"] == "content"

View File

@ -316,7 +316,7 @@ class TestTagServiceRetrieval:
- get_tags_by_target_id: Get all tags bound to a specific target
"""
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_tags_with_binding_counts(self, mock_db_session, factory):
"""
Test retrieving tags with their binding counts.
@ -373,7 +373,7 @@ class TestTagServiceRetrieval:
# Verify database query was called
mock_db_session.query.assert_called_once()
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_tags_with_keyword_filter(self, mock_db_session, factory):
"""
Test retrieving tags filtered by keyword (case-insensitive).
@ -427,7 +427,7 @@ class TestTagServiceRetrieval:
# 2. Additional WHERE clause for keyword filtering
assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause"
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_target_ids_by_tag_ids(self, mock_db_session, factory):
"""
Test retrieving target IDs by tag IDs.
@ -483,7 +483,7 @@ class TestTagServiceRetrieval:
# Verify both queries were executed
assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query"
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory):
"""
Test that empty tag_ids returns empty list.
@ -511,7 +511,7 @@ class TestTagServiceRetrieval:
assert results == [], "Should return empty list for empty input"
mock_db_session.scalars.assert_not_called(), "Should not query database for empty input"
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_tag_by_tag_name(self, mock_db_session, factory):
"""
Test retrieving tags by name.
@ -553,7 +553,7 @@ class TestTagServiceRetrieval:
assert len(results) == 1, "Should find exactly one tag"
assert results[0].name == tag_name, "Tag name should match"
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory):
"""
Test that missing tag_type or tag_name returns empty list.
@ -581,7 +581,7 @@ class TestTagServiceRetrieval:
# Verify no database queries were executed
mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input"
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_tags_by_target_id(self, mock_db_session, factory):
"""
Test retrieving tags associated with a specific target.
@ -654,7 +654,7 @@ class TestTagServiceCRUD:
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
@patch("services.tag_service.uuid.uuid4", autospec=True)
def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
"""
@ -743,7 +743,7 @@ class TestTagServiceCRUD:
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
"""
Test updating a tag name.
@ -795,7 +795,7 @@ class TestTagServiceCRUD:
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_update_tags_raises_error_for_duplicate_name(
self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory
):
@ -827,7 +827,7 @@ class TestTagServiceCRUD:
with pytest.raises(ValueError, match="Tag name already exists"):
TagService.update_tags(args, tag_id="tag-123")
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory):
"""
Test that updating a non-existent tag raises NotFound.
@ -859,7 +859,7 @@ class TestTagServiceCRUD:
with pytest.raises(NotFound, match="Tag not found"):
TagService.update_tags(args, tag_id="nonexistent")
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_get_tag_binding_count(self, mock_db_session, factory):
"""
Test getting the count of bindings for a tag.
@ -895,7 +895,7 @@ class TestTagServiceCRUD:
# Verify count matches expectation
assert result == expected_count, "Binding count should match"
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_delete_tag(self, mock_db_session, factory):
"""
Test deleting a tag and its bindings.
@ -951,7 +951,7 @@ class TestTagServiceCRUD:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_delete_tag_raises_not_found(self, mock_db_session, factory):
"""
Test that deleting a non-existent tag raises NotFound.
@ -999,7 +999,7 @@ class TestTagServiceBindings:
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory):
"""
Test creating tag bindings.
@ -1050,7 +1050,7 @@ class TestTagServiceBindings:
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory):
"""
Test that saving duplicate bindings is idempotent.
@ -1090,7 +1090,7 @@ class TestTagServiceBindings:
mock_db_session.add.assert_not_called(), "Should not create duplicate binding"
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory):
"""
Test deleting a tag binding.
@ -1138,7 +1138,7 @@ class TestTagServiceBindings:
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory):
"""
Test that deleting a non-existent binding is a no-op.
@ -1175,7 +1175,7 @@ class TestTagServiceBindings:
mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete"
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory):
"""
Test validating that a dataset target exists.
@ -1216,7 +1216,7 @@ class TestTagServiceBindings:
mock_db_session.query.assert_called_once(), "Should query database for dataset"
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory):
"""
Test validating that an app target exists.
@ -1257,7 +1257,7 @@ class TestTagServiceBindings:
mock_db_session.query.assert_called_once(), "Should query database for app"
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_check_target_exists_raises_not_found_for_missing_dataset(
self, mock_db_session, mock_current_user, factory
):
@ -1289,7 +1289,7 @@ class TestTagServiceBindings:
TagService.check_target_exists("knowledge", "nonexistent")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.db.session")
def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory):
"""
Test that missing app raises NotFound.

View File

@ -59,6 +59,11 @@ def mock_redis():
# Redis is already mocked globally in conftest.py
# Reset it for each test
redis_client.reset_mock()
redis_client.get.reset_mock()
redis_client.setex.reset_mock()
redis_client.delete.reset_mock()
redis_client.lpush.reset_mock()
redis_client.rpop.reset_mock()
redis_client.get.return_value = None
redis_client.setex.return_value = True
redis_client.delete.return_value = True