diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index c7b6593a8f..df02c584ed 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector): ) ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) docs = [] for doc, score in docs_and_scores: - score_threshold = float(kwargs.get("score_threshold") or 0.0) if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score - docs.append(doc) + docs.append(doc) return docs diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py index 65ee62b8dd..c7a4265a95 100644 --- a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py +++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py @@ -211,3 +211,16 @@ class TestCleanProcessor: text = "[Text with (parens) and symbols](https://example.com)" expected = "[Text with (parens) and symbols](https://example.com)" assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_remove_urls_emails_preserves_markdown_image_links(self): + """Remove plain URLs and emails while preserving markdown image links.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + text = "Email test@example.com and remove https://remove.com but keep ![diagram](https://example.com/image.png)" + result = CleanProcessor.clean(text, process_rule) + + assert result == "Email and remove but keep ![diagram](https://example.com/image.png)" + + def test_filter_string_returns_input_text(self): + """Test filter_string passthrough behavior.""" + processor = CleanProcessor() + assert processor.filter_string("raw text") == "raw text" diff --git 
a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py new file mode 100644 index 0000000000..538457ccc8 --- /dev/null +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -0,0 +1,249 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + +def _doc(content: str) -> Document: + return Document(page_content=content) + + +class TestDataPostProcessor: + def test_init_sets_rerank_and_reorder_runners(self): + rerank_runner = object() + reorder_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock: + with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock: + processor = DataPostProcessor( + tenant_id="tenant-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_model={"config": "value"}, + weights={"weight": "value"}, + reorder_enabled=True, + ) + + assert processor.rerank_runner is rerank_runner + assert processor.reorder_runner is reorder_runner + rerank_mock.assert_called_once_with( + RerankMode.WEIGHTED_SCORE, + "tenant-1", + {"config": "value"}, + {"weight": "value"}, + ) + reorder_mock.assert_called_once_with(True) + + def test_invoke_applies_rerank_then_reorder(self): + original_documents = [_doc("doc-a")] + reranked_documents = [_doc("doc-b")] + reordered_documents = [_doc("doc-c")] + + processor = 
DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = MagicMock() + processor.rerank_runner.run.return_value = reranked_documents + processor.reorder_runner = MagicMock() + processor.reorder_runner.run.return_value = reordered_documents + + result = processor.invoke( + query="how to test", + documents=original_documents, + score_threshold=0.3, + top_n=2, + user="user-1", + query_type=QueryType.IMAGE_QUERY, + ) + + processor.rerank_runner.run.assert_called_once_with( + "how to test", + original_documents, + 0.3, + 2, + "user-1", + QueryType.IMAGE_QUERY, + ) + processor.reorder_runner.run.assert_called_once_with(reranked_documents) + assert result == reordered_documents + + def test_invoke_returns_original_documents_when_no_runner_is_configured(self): + documents = [_doc("doc-a"), _doc("doc-b")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = None + processor.reorder_runner = None + + assert processor.invoke(query="query", documents=documents) == documents + + def test_get_rerank_runner_for_weighted_score(self): + weights_config = { + "vector_setting": { + "vector_weight": 0.7, + "embedding_provider_name": "provider-x", + "embedding_model_name": "embedding-y", + }, + "keyword_setting": {"keyword_weight": 0.3}, + } + expected_runner = object() + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.WEIGHTED_SCORE, + tenant_id="tenant-1", + reranking_model=None, + weights=weights_config, + ) + + assert result is expected_runner + kwargs = factory_mock.call_args.kwargs + assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE + assert kwargs["tenant_id"] == "tenant-1" + assert kwargs["weights"].vector_setting.vector_weight == 0.7 + assert 
kwargs["weights"].vector_setting.embedding_provider_name == "provider-x" + assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y" + assert kwargs["weights"].keyword_setting.keyword_weight == 0.3 + + def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + reranking_model = { + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + } + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock: + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner" + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model=reranking_model, + weights=None, + ) + + assert result is None + model_mock.assert_called_once_with("tenant-1", reranking_model) + factory_mock.assert_not_called() + + def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + expected_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance): + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + }, + weights=None, + ) + + assert result is expected_runner + factory_mock.assert_called_once_with( + runner_type=RerankMode.RERANKING_MODEL, + rerank_model_instance=model_instance, + ) + + def test_get_rerank_runner_returns_none_for_unsupported_mode(self): + processor = 
DataPostProcessor.__new__(DataPostProcessor) + + assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None + assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None + + def test_get_reorder_runner_by_flag(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert isinstance(processor._get_reorder_runner(True), ReorderRunner) + assert processor._get_reorder_runner(False) is None + + def test_get_rerank_model_instance_returns_none_when_config_is_missing(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + assert processor._get_rerank_model_instance("tenant-1", None) is None + + def test_get_rerank_model_instance_raises_key_error_for_incomplete_config(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: + manager_instance = manager_cls.return_value + with pytest.raises(KeyError, match="reranking_model_name"): + processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={"reranking_provider_name": "provider-x"}, + ) + + manager_instance.get_model_instance.assert_not_called() + + def test_get_rerank_model_instance_success(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: + manager_instance = manager_cls.return_value + manager_instance.get_model_instance.return_value = model_instance + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is model_instance + manager_instance.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider-x", + model_type=ModelType.RERANK, + model="reranker-1", + ) + + def 
test_get_rerank_model_instance_handles_authorization_error(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: + manager_instance = manager_cls.return_value + manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized") + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is None + + +class TestReorderRunner: + def test_run_reorders_even_sized_document_list(self): + documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4"), _doc("5")] + + reordered = ReorderRunner().run(documents) + + assert [document.page_content for document in reordered] == ["0", "2", "4", "5", "3", "1"] + + def test_run_handles_odd_sized_and_empty_document_lists(self): + odd_documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4")] + runner = ReorderRunner() + + odd_reordered = runner.run(odd_documents) + + assert [document.page_content for document in odd_reordered] == ["0", "2", "4", "3", "1"] + assert runner.run([]) == [] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py new file mode 100644 index 0000000000..795a325a6b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -0,0 +1,414 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.keyword.jieba.jieba as jieba_module +from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default +from core.rag.models.document import Document + + +class _DummyLock: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _Field: + def 
__init__(self, name: str): + self._name = name + + def __eq__(self, other): + return ("eq", self._name, other) + + def in_(self, values): + return ("in", self._name, tuple(values)) + + +class _FakeQuery: + def __init__(self): + self.where_calls: list[tuple] = [] + + def where(self, *conditions): + self.where_calls.append(conditions) + return self + + +class _FakeExecuteResult: + def __init__(self, segments: list[SimpleNamespace]): + self._segments = segments + + def scalars(self): + return self + + def all(self): + return self._segments + + +class _FakeSelect: + def __init__(self): + self.where_conditions: tuple | None = None + + def where(self, *conditions): + self.where_conditions = conditions + return self + + +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): + return SimpleNamespace( + data_source_type=data_source_type, + keyword_table_dict=keyword_table_dict, + keyword_table="", + ) + + +def _dataset(dataset_keyword_table=None, keyword_number=None): + return SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + keyword_number=keyword_number, + dataset_keyword_table=dataset_keyword_table, + ) + + +@pytest.fixture +def patched_runtime(monkeypatch): + session = MagicMock() + db = SimpleNamespace(session=session) + storage = MagicMock() + lock = MagicMock(return_value=_DummyLock()) + redis_client = SimpleNamespace(lock=lock) + + monkeypatch.setattr(jieba_module, "db", db) + monkeypatch.setattr(jieba_module, "storage", storage) + monkeypatch.setattr(jieba_module, "redis_client", redis_client) + + return SimpleNamespace(session=session, storage=storage, lock=lock) + + +def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime): + dataset = _dataset(_dataset_keyword_table(), keyword_number=2) + keyword = Jieba(dataset) + handler = MagicMock() + handler.extract_keywords.return_value = {"kw1", "kw2"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + 
monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + result = keyword.create( + [ + Document(page_content="alpha", metadata={"doc_id": "node-1"}), + SimpleNamespace(page_content="ignored", metadata=None), + ] + ) + + assert result is keyword + keyword._update_segment_keywords.assert_called_once() + call_args = keyword._update_segment_keywords.call_args.args + assert call_args[0] == "dataset-1" + assert call_args[1] == "node-1" + assert set(call_args[2]) == {"kw1", "kw2"} + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["kw1"] == {"node-1"} + assert saved_table["kw2"] == {"node-1"} + patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600) + + +def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + texts = [ + Document(page_content="extract-this", metadata={"doc_id": "node-1"}), + Document(page_content="use-manual", metadata={"doc_id": "node-2"}), + ] + keyword.add_texts(texts, keywords_list=[[], ["manual"]]) + + assert keyword._update_segment_keywords.call_count == 2 + first_call = keyword._update_segment_keywords.call_args_list[0].args + second_call = keyword._update_segment_keywords.call_args_list[1].args + assert set(first_call[2]) == {"auto"} + assert second_call[2] == ["manual"] + 
keyword._save_dataset_keyword_table.assert_called_once() + + +def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1)) + handler = MagicMock() + handler.extract_keywords.return_value = {"from-extractor"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})]) + + handler.extract_keywords.assert_called_once_with("content", 1) + assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"} + + +def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + assert keyword.text_exists("node-1") is False + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + assert keyword.text_exists("node-2") is True + assert keyword.text_exists("node-x") is False + + +def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"]) + 
keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}}) + + +def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_not_called() + keyword._save_dataset_keyword_table.assert_called_once_with(None) + + +def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + document_id = _Field("document_id") + + keyword = Jieba(_dataset(_dataset_keyword_table())) + query_stmt = _FakeQuery() + patched_runtime.session.query.return_value = query_stmt + patched_runtime.session.execute.return_value = _FakeExecuteResult( + [ + SimpleNamespace( + index_node_id="node-2", + content="segment-content", + index_node_hash="hash-2", + document_id="doc-2", + dataset_id="dataset-1", + ) + ] + ) + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"])) + + documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"]) + + assert len(query_stmt.where_calls) == 2 + assert len(documents) == 1 + assert documents[0].page_content == "segment-content" + assert documents[0].metadata["doc_id"] == "node-2" + assert documents[0].metadata["doc_hash"] == "hash-2" + + +def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime): + db_keyword = 
_dataset_keyword_table(data_source_type="database") + file_keyword = _dataset_keyword_table(data_source_type="object_storage") + + keyword_db = Jieba(_dataset(db_keyword)) + keyword_db.delete() + patched_runtime.storage.delete.assert_not_called() + + keyword_file = Jieba(_dataset(file_keyword)) + keyword_file.delete() + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + assert patched_runtime.session.delete.call_count == 2 + assert patched_runtime.session.commit.call_count == 2 + + +def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="database") + keyword = Jieba(_dataset(dataset_keyword_table)) + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table + assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table + patched_runtime.session.commit.assert_called_once() + + +def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="file") + keyword = Jieba(_dataset(dataset_keyword_table)) + patched_runtime.storage.exists.return_value = True + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + patched_runtime.storage.save.assert_called_once() + save_args = patched_runtime.storage.save.call_args.args + assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt" + assert isinstance(save_args[1], bytes) + + +def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime): + existing = _dataset_keyword_table( + keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}} + ) + keyword = Jieba(_dataset(existing)) + assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]} + + 
missing_payload = _dataset_keyword_table(keyword_table_dict=None) + keyword_with_missing_payload = Jieba(_dataset(missing_payload)) + assert keyword_with_missing_payload._get_dataset_keyword_table() == {} + + +def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime): + created_tables: list[SimpleNamespace] = [] + + def _fake_dataset_keyword_table(**kwargs): + kwargs.setdefault("keyword_table", "") + kwargs.setdefault("keyword_table_dict", None) + table = SimpleNamespace(**kwargs) + created_tables.append(table) + return table + + keyword = Jieba(_dataset(dataset_keyword_table=None)) + monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table) + monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database") + + result = keyword._get_dataset_keyword_table() + + assert result == {} + assert len(created_tables) == 1 + assert created_tables[0].dataset_id == "dataset-1" + assert created_tables[0].data_source_type == "database" + assert '"index_id":"dataset-1"' in created_tables[0].keyword_table + patched_runtime.session.add.assert_called_once_with(created_tables[0]) + patched_runtime.session.commit.assert_called_once() + + +def test_add_and_delete_ids_from_keyword_table_helpers(): + keyword = Jieba(_dataset(_dataset_keyword_table())) + keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}} + + updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"]) + assert updated["kw1"] == {"node-1", "node-3"} + assert updated["kw3"] == {"node-3"} + + deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"]) + assert "kw3" not in deleted + assert "kw1" not in deleted + assert deleted["kw2"] == {"node-2"} + + +def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + handler = MagicMock() + handler.extract_keywords.return_value = ["kw-a", "kw-b"] + 
monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + + ranked_ids = keyword._retrieve_ids_by_query( + {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}}, + "query", + k=1, + ) + + assert ranked_ids == ["node-2"] + + +def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect()) + + keyword = Jieba(_dataset(_dataset_keyword_table())) + segment = SimpleNamespace(keywords=[]) + patched_runtime.session.scalar.return_value = segment + + keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"]) + + assert segment.keywords == ["kw1", "kw2"] + patched_runtime.session.add.assert_called_once_with(segment) + patched_runtime.session.commit.assert_called_once() + + patched_runtime.session.reset_mock() + patched_runtime.session.scalar.return_value = None + + keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"]) + + patched_runtime.session.add.assert_not_called() + patched_runtime.session.commit.assert_not_called() + + +def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.create_segment_keywords("node-1", ["kw"]) + keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"]) + keyword._save_dataset_keyword_table.assert_called_once() + + keyword._save_dataset_keyword_table.reset_mock() + keyword.update_segment_keywords_index("node-2", ["kw2"]) + 
keyword._save_dataset_keyword_table.assert_called_once() + + +def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None) + second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None) + + keyword.multi_create_segment_keywords( + [ + {"segment": first_segment, "keywords": ["manual"]}, + {"segment": second_segment, "keywords": []}, + ] + ) + + assert first_segment.keywords == ["manual"] + assert second_segment.keywords == ["auto"] + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["manual"] == {"node-1"} + assert saved_table["auto"] == {"node-2"} + + +def test_set_orjson_default_and_dumps_with_sets(): + assert set(set_orjson_default({"a", "b"})) == {"a", "b"} + + with pytest.raises(TypeError, match="is not JSON serializable"): + set_orjson_default(("not", "a", "set")) + + payload = {"items": {"a", "b"}} + json_payload = dumps_with_sets(payload) + decoded = json.loads(json_payload) + assert set(decoded["items"]) == {"a", "b"} diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py new file mode 100644 index 0000000000..a4586c141b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py @@ -0,0 +1,142 @@ +import sys +import types +from types import SimpleNamespace + +from 
core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +class _DummyTFIDF: + def __init__(self): + self.stop_words = set() + + @staticmethod + def extract_tags(sentence: str, top_k: int | None = 20, **kwargs): + return ["alpha_beta", "during", "gamma"] + + +def _install_fake_jieba_modules( + monkeypatch, + analyse_module: types.ModuleType, + jieba_attrs: dict[str, object] | None = None, + tfidf_module: types.ModuleType | None = None, +): + jieba_module = types.ModuleType("jieba") + jieba_module.__path__ = [] + if jieba_attrs: + for key, value in jieba_attrs.items(): + setattr(jieba_module, key, value) + + jieba_module.analyse = analyse_module + analyse_module.__package__ = "jieba" + + monkeypatch.setitem(sys.modules, "jieba", jieba_module) + monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module) + if tfidf_module is not None: + monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module) + else: + monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False) + + +def test_init_uses_existing_default_tfidf(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + default_tfidf = _DummyTFIDF() + analyse_module.default_tfidf = default_tfidf + + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert handler._tfidf is default_tfidf + assert handler._tfidf.stop_words == STOPWORDS + + +def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + class _TFIDFFactory(_DummyTFIDF): + pass + + analyse_module.TFIDF = _TFIDFFactory + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _TFIDFFactory) + assert analyse_module.default_tfidf is handler._tfidf + + +def 
test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + tfidf_module = types.ModuleType("jieba.analyse.tfidf") + + class _ImportedTFIDF(_DummyTFIDF): + pass + + tfidf_module.TFIDF = _ImportedTFIDF + _install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _ImportedTFIDF) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1) + + assert fallback_keywords == ["two"] + + +def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]}) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["x"] + + +def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules( + monkeypatch, + analyse_module, + jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])}, + ) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["foo"] + + +def test_extract_keywords_expands_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"]) + + keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3) + + assert 
"alpha-beta" in keywords + assert "alpha" in keywords + assert "beta" in keywords + assert "during" in keywords + assert "gamma" in keywords + + +def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + + expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"}) + + assert "alpha-during-beta" in expanded + assert "alpha" in expanded + assert "beta" in expanded + assert "during" not in expanded diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py new file mode 100644 index 0000000000..1b1541ddd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py @@ -0,0 +1,6 @@ +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +def test_stopwords_loaded(): + assert "during" in STOPWORDS + assert "the" in STOPWORDS diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py new file mode 100644 index 0000000000..55e22aea0a --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py @@ -0,0 +1,97 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document + + +class _KeywordThatRaises(BaseKeyword): + def create(self, texts: list[Document], **kwargs): + return super().create(texts, **kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + return super().add_texts(texts, **kwargs) + + def text_exists(self, id: str) -> bool: + return super().text_exists(id) + + def delete_by_ids(self, ids: list[str]): + return super().delete_by_ids(ids) + + def delete(self): + return super().delete() + + def search(self, query: str, **kwargs): + return super().search(query, **kwargs) + + 
class _KeywordForHelpers(BaseKeyword):
    """Minimal BaseKeyword implementation used to exercise the inherited helper methods."""

    def __init__(self, dataset, existing_ids: set[str] | None = None):
        super().__init__(dataset)
        # doc_ids reported as already present by text_exists()
        self._existing_ids = existing_ids or set()

    def create(self, texts: list[Document], **kwargs):
        return self

    def add_texts(self, texts: list[Document], **kwargs):
        return None

    def text_exists(self, id: str) -> bool:
        return id in self._existing_ids

    def delete_by_ids(self, ids: list[str]):
        return None

    def delete(self):
        return None

    def search(self, query: str, **kwargs):
        return []


def test_abstract_methods_raise_not_implemented():
    """Every BaseKeyword method is abstract-by-convention and raises NotImplementedError."""
    keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1"))

    with pytest.raises(NotImplementedError):
        keyword.create([])

    with pytest.raises(NotImplementedError):
        keyword.add_texts([])

    with pytest.raises(NotImplementedError):
        keyword.text_exists("doc-1")

    with pytest.raises(NotImplementedError):
        keyword.delete_by_ids(["doc-1"])

    with pytest.raises(NotImplementedError):
        keyword.delete()

    with pytest.raises(NotImplementedError):
        keyword.search("query")


def test_filter_duplicate_texts_removes_existing_doc_ids():
    """_filter_duplicate_texts drops docs whose doc_id already exists; docs without metadata pass through."""
    keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"})
    texts = [
        Document(page_content="keep", metadata={"doc_id": "keep"}),
        Document(page_content="duplicate", metadata={"doc_id": "duplicate"}),
        SimpleNamespace(page_content="without-metadata", metadata=None),
    ]

    filtered = keyword._filter_duplicate_texts(texts)

    assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"]
    assert any(text.metadata is None for text in filtered)


def test_get_uuids_returns_only_docs_with_metadata():
    """_get_uuids collects doc_ids and silently skips entries that carry no metadata."""
    keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"))
    texts = [
        Document(page_content="doc-1", metadata={"doc_id": "doc-1"}),
        Document(page_content="doc-2", metadata={"doc_id": "doc-2"}),
        SimpleNamespace(page_content="doc-3", metadata=None),
    ]

    assert keyword._get_uuids(texts) == ["doc-1", "doc-2"]


# --- new file: api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py ---
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest

from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.keyword.keyword_type import KeyWordType
from core.rag.models.document import Document


def test_get_keyword_factory_returns_jieba_factory(monkeypatch):
    """JIEBA keyword type resolves to the Jieba class imported from its submodule."""
    fake_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba")

    class FakeJieba:
        pass

    fake_module.Jieba = FakeJieba
    monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", fake_module)

    assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba


def test_get_keyword_factory_raises_for_unsupported_type():
    """Unknown keyword-store names raise ValueError with a descriptive message."""
    with pytest.raises(ValueError, match="Keyword store unsupported is not supported"):
        Keyword.get_keyword_factory("unsupported")


def test_keyword_initialization_uses_configured_factory(monkeypatch):
    """Keyword() instantiates the processor produced by the factory selected via dify_config."""
    dataset = SimpleNamespace(id="dataset-1")
    fake_processor = MagicMock()

    monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA)
    monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor))

    keyword = Keyword(dataset)

    assert keyword._keyword_processor is fake_processor


def test_keyword_methods_forward_to_processor():
    """Every public Keyword method delegates verbatim (args and kwargs) to the wrapped processor."""
    processor = MagicMock()
    processor.text_exists.return_value = True
    processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})]

    keyword = Keyword.__new__(Keyword)
    keyword._keyword_processor = processor

    docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})]
    keyword.create(docs, foo="bar")
    keyword.add_texts(docs, batch=True)
    assert keyword.text_exists("doc-1") is True
    keyword.delete_by_ids(["doc-1"])
    keyword.delete()
    assert keyword.search("query", top_k=1) == processor.search.return_value

    processor.create.assert_called_once_with(docs, foo="bar")
    processor.add_texts.assert_called_once_with(docs, batch=True)
    processor.text_exists.assert_called_once_with("doc-1")
    processor.delete_by_ids.assert_called_once_with(["doc-1"])
    processor.delete.assert_called_once()
    processor.search.assert_called_once_with("query", top_k=1)


def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes():
    """__getattr__ forwards callables only; non-callables and missing processors raise AttributeError."""

    class Processor:
        value = 1

        @staticmethod
        def custom():
            return "ok"

    keyword = Keyword.__new__(Keyword)
    keyword._keyword_processor = Processor()

    assert keyword.custom() == "ok"

    with pytest.raises(AttributeError):
        _ = keyword.value

    keyword._keyword_processor = None
    with pytest.raises(AttributeError):
        _ = keyword.missing_method


# --- new file: api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py ---
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, call, patch
from uuid import uuid4

import pytest

from core.rag.datasource import retrieval_service as retrieval_service_module
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from models.dataset import Dataset


def create_mock_document(
    content: str,
    doc_id: str,
    score: float = 0.8,
    provider: str = "dify",
    additional_metadata: dict | None = None,
) -> Document:
    """
    Create a mock Document object for testing.

    This helper function standardizes document creation across tests,
    ensuring consistent structure and reducing code duplication.

    Args:
        content: The text content of the document
        doc_id: Unique identifier for the document chunk
        score: Relevance score (0.0 to 1.0)
        provider: Document provider ("dify" or "external")
        additional_metadata: Optional extra metadata fields

    Returns:
        Document: A properly structured Document object

    Example:
        >>> doc = create_mock_document("Python is great", "doc1", score=0.95)
        >>> assert doc.metadata["score"] == 0.95
    """
    metadata = {
        "doc_id": doc_id,
        "document_id": str(uuid4()),
        "dataset_id": str(uuid4()),
        "score": score,
    }

    # Merge additional metadata if provided
    if additional_metadata:
        metadata.update(additional_metadata)

    return Document(
        page_content=content,
        metadata=metadata,
        provider=provider,
    )


class _ImmediateFuture:
    """Future stand-in whose result is already known; records whether cancel() was called."""

    def __init__(self, exception: Exception | None = None) -> None:
        self._exception = exception
        self.cancel_called = False

    def exception(self) -> Exception | None:
        return self._exception

    def cancel(self) -> None:
        self.cancel_called = True


class _ImmediateExecutor:
    """Executor stand-in that runs submitted callables synchronously on submit()."""

    def __init__(self) -> None:
        self.futures: list[_ImmediateFuture] = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        return False

    def submit(self, fn, *args, **kwargs):
        try:
            fn(*args, **kwargs)
            future = _ImmediateFuture()
        except Exception as exc:  # pragma: no cover - only for defensive parity with Future semantics
            future = _ImmediateFuture(exc)
        self.futures.append(future)
        return future


class _FakeExecuteScalarResult:
    """Mimics SQLAlchemy's ScalarResult: .all() returns the pre-seeded rows."""

    def __init__(self, data: list) -> None:
        self._data = data

    def all(self) -> list:
        return self._data


class _FakeExecuteResult:
    """Mimics a SQLAlchemy Result whose .scalars() yields a _FakeExecuteScalarResult."""

    def __init__(self, data: list) -> None:
        self._data = data

    def scalars(self) -> _FakeExecuteScalarResult:
        return _FakeExecuteScalarResult(self._data)


class _FakeSummaryQuery:
    """Query stub used for the summary lookup: filter() is a no-op, all() returns the seed."""

    def __init__(self, summaries: list) -> None:
        self._summaries = summaries

    def filter(self, *args, **kwargs):
        return self

    def all(self) -> list:
        return self._summaries


class _FakeSession:
    """Session stub: execute() pops pre-seeded payloads in order; query() serves summaries."""

    def __init__(self, execute_payloads: list[list], summaries: list) -> None:
        self._payloads = list(execute_payloads)
        self._summaries = summaries

    def execute(self, stmt):
        # Each execute() consumes the next payload; empty list once exhausted.
        data = self._payloads.pop(0) if self._payloads else []
        return _FakeExecuteResult(data)

    def query(self, model):
        return _FakeSummaryQuery(self._summaries)


class _FakeSessionContext:
    """Context manager wrapper so the stub session works with `with create_session() as s:`."""

    def __init__(self, session: _FakeSession) -> None:
        self._session = session

    def __enter__(self) -> _FakeSession:
        return self._session

    def __exit__(self, exc_type, exc, tb) -> bool:
        return False


class _SimpleRetrievalChildChunk:
    """Plain-attribute replacement for RetrievalChildChunk."""

    def __init__(self, id: str, content: str, score: float, position: int) -> None:
        self.id = id
        self.content = content
        self.score = score
        self.position = position


class _SimpleRetrievalSegment:
    """Plain-attribute replacement for RetrievalSegments."""

    def __init__(
        self,
        segment,
        child_chunks: list[_SimpleRetrievalChildChunk] | None = None,
        score: float | None = None,
        files: list[dict[str, str | int]] | None = None,
        summary: str | None = None,
    ) -> None:
        self.segment = segment
        self.child_chunks = child_chunks
        self.score = score
        self.files = files
        self.summary = summary


class TestRetrievalServiceInternals:
    @pytest.fixture
    def internal_dataset(self) -> Dataset:
        """A mocked Dataset configured for parent/child indexing without multimodal support."""
        dataset = Mock(spec=Dataset)
        dataset.id = "dataset-id"
        dataset.tenant_id = "tenant-id"
dataset.is_multimodal = False + dataset.doc_form = IndexStructureType.PARENT_CHILD_INDEX + return dataset + + @pytest.fixture + def internal_flask_app(self): + app = MagicMock() + app.app_context.return_value.__enter__ = Mock() + app.app_context.return_value.__exit__.return_value = False + return app + + def test_retrieve_with_attachment_ids_only(self, monkeypatch, internal_dataset): + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset", return_value=internal_dataset), + patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") as mock_retrieve, + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + def side_effect( + flask_app, + retrieval_method, + dataset, + all_documents, + exceptions, + query=None, + top_k=4, + score_threshold=0.0, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + ): + all_documents.append(create_mock_document(f"content-{attachment_id}", attachment_id or "none", 0.9)) + + mock_retrieve.side_effect = side_effect + + results = RetrievalService.retrieve( + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset_id=internal_dataset.id, + query="", + attachment_ids=["att-1", "att-2"], + ) + + assert len(results) == 2 + assert {doc.metadata["doc_id"] for doc in results} == {"att-1", "att-2"} + assert mock_retrieve.call_count == 2 + + @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") + @patch("core.rag.datasource.retrieval_service.MetadataCondition.model_validate") + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): + 
mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") + mock_validate.return_value = "validated-condition" + expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] + mock_fetch.return_value = expected_documents + + results = RetrievalService.external_retrieve( + dataset_id="dataset-1", + query="test query", + external_retrieval_model={"top_k": 3}, + metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"}, + ) + + assert results == expected_documents + mock_validate.assert_called_once() + mock_fetch.assert_called_once_with( + "tenant-1", + "dataset-1", + "test query", + {"top_k": 3}, + metadata_condition="validated-condition", + ) + + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_returns_empty_when_dataset_not_found(self, mock_scalar): + mock_scalar.return_value = None + + results = RetrievalService.external_retrieve(dataset_id="missing", query="q") + + assert results == [] + + @patch("core.rag.datasource.retrieval_service.Session") + def test_get_dataset_queries_by_id(self, mock_session_class): + expected_dataset = Mock(spec=Dataset) + mock_session = Mock() + mock_session.query.return_value.where.return_value.first.return_value = expected_dataset + mock_session_class.return_value.__enter__.return_value = mock_session + + with patch.object(retrieval_service_module, "db", SimpleNamespace(engine=Mock())): + result = RetrievalService._get_dataset("dataset-123") + + assert result == expected_dataset + mock_session.query.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_success(self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.return_value = 
[create_mock_document("keyword-content", "kw-1", 0.91)] + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "with quotes"', + top_k=5, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + keyword_instance.search.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_dataset_missing(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.side_effect = RuntimeError("keyword failed") + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["keyword failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + 
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [create_mock_document("vector-content", "vec-1", 0.7)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + document_ids_filter=["doc-1"], + query_type=QueryType.TEXT_QUERY, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_vector.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_non_multimodal_returns_early( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-1", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == [] + assert exceptions == [] + vector_instance.search_by_file.assert_not_called() 
+ + @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_with_vision_reranking( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + reranked_docs = [create_mock_document("image-content-reranked", "img-doc", 0.97)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = True + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + model_manager.check_model_support_vision.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + 
@patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_without_vision_support( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = [create_mock_document("unused", "unused", 0.1)] + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = False + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == original_docs + assert exceptions == [] + processor_instance.invoke.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_with_reranking_non_multimodal( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False 
+ mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("vector-content", "vec-doc", 0.62)] + reranked_docs = [create_mock_document("vector-content-reranked", "vec-doc", 0.89)] + + vector_instance = Mock() + vector_instance.search_by_vector.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_appends_exception_when_vector_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.side_effect = RuntimeError("vector failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert 
all_documents == [] + assert exceptions == ["vector failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = [create_mock_document("fulltext", "ft-1", 0.68)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "x"', + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_full_text.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_with_reranking( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("fulltext", "ft-1", 0.68)] + reranked_docs = [create_mock_document("fulltext-reranked", "ft-1", 0.9)] + + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + 
RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_dataset_not_found(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.side_effect = RuntimeError("fulltext failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + 
exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["fulltext failed"] + + def test_format_retrieval_documents_with_empty_input_returns_empty_list(self): + assert RetrievalService.format_retrieval_documents([]) == [] + + def test_format_retrieval_documents_without_document_id_returns_empty_list(self): + documents = [Document(page_content="content", metadata={"doc_id": "doc-1", "score": 0.4}, provider="dify")] + + assert RetrievalService.format_retrieval_documents(documents) == [] + + def test_format_retrieval_documents_with_parent_child_summary_and_attachments(self, monkeypatch): + dataset_doc_parent = SimpleNamespace( + id="doc-parent", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + dataset_doc_text = SimpleNamespace(id="doc-text", doc_form="paragraph", dataset_id="dataset-id") + dataset_doc_parent_summary = SimpleNamespace( + id="doc-parent-summary", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + + dataset_query = Mock() + dataset_query.where.return_value.options.return_value.all.return_value = [ + dataset_doc_parent, + dataset_doc_text, + dataset_doc_parent_summary, + ] + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query)) + monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk) + monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment) + + input_documents = [ + Document( + page_content="child node content", + metadata={"document_id": "doc-parent", "doc_id": "child-node-1", "score": 0.7}, + provider="dify", + ), + Document( + page_content="parent image", + metadata={ + "document_id": "doc-parent", + "doc_id": "attach-node-1", + "doc_type": DocType.IMAGE, + "score": 0.8, + }, + provider="dify", + ), + Document( + page_content="text index node", + metadata={"document_id": "doc-text", "doc_id": "index-node-1", "score": 0.6}, + 
provider="dify", + ), + Document( + page_content="text image node", + metadata={ + "document_id": "doc-text", + "doc_id": "attach-text-1", + "doc_type": DocType.IMAGE, + "score": 0.65, + }, + provider="dify", + ), + Document( + page_content="summary candidate 1", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-1", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.9", + }, + provider="dify", + ), + Document( + page_content="summary candidate 2", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-2", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.95", + }, + provider="dify", + ), + Document( + page_content="invalid score summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-invalid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "invalid", + }, + provider="dify", + ), + Document( + page_content="valid parent summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-valid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "0.4", + }, + provider="dify", + ), + ] + + child_chunk = SimpleNamespace( + id="child-chunk-1", + segment_id="segment-parent", + index_node_id="child-node-1", + content="child details", + position=2, + ) + segment_parent = SimpleNamespace(id="segment-parent", document_id="doc-parent", index_node_id="parent-node") + segment_text = SimpleNamespace(id="segment-text", document_id="doc-text", index_node_id="index-node-1") + segment_summary = SimpleNamespace(id="segment-summary", document_id="doc-text", index_node_id="summary-node") + segment_parent_summary = SimpleNamespace( + id="segment-parent-summary", + document_id="doc-parent-summary", + index_node_id="summary-parent-node", + ) + + fake_session = _FakeSession( + execute_payloads=[ + [child_chunk], + [segment_text], + [segment_parent, segment_text], + [segment_summary, 
segment_parent_summary], + ], + summaries=[ + SimpleNamespace(chunk_id="segment-summary", summary_content="summary for text"), + SimpleNamespace(chunk_id="segment-parent-summary", summary_content="summary for parent"), + ], + ) + monkeypatch.setattr( + retrieval_service_module.session_factory, + "create_session", + lambda: _FakeSessionContext(fake_session), + ) + monkeypatch.setattr( + RetrievalService, + "get_segment_attachment_infos", + lambda attachment_ids, session: [ + { + "attachment_id": "attach-node-1", + "attachment_info": { + "id": "attach-node-1", + "name": "img-parent", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://parent", + "size": 11, + }, + "segment_id": "segment-parent", + }, + { + "attachment_id": "attach-text-1", + "attachment_info": { + "id": "attach-text-1", + "name": "img-text", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://text", + "size": 22, + }, + "segment_id": "segment-text", + }, + ], + ) + + result = RetrievalService.format_retrieval_documents(input_documents) + + assert len(result) == 4 + result_by_segment_id = {item.segment.id: item for item in result} + assert result_by_segment_id["segment-summary"].score == pytest.approx(0.95) + assert result_by_segment_id["segment-summary"].summary == "summary for text" + assert result_by_segment_id["segment-parent"].score == pytest.approx(0.8) + assert result_by_segment_id["segment-parent"].files is not None + assert len(result_by_segment_id["segment-parent"].child_chunks or []) == 1 + assert result_by_segment_id["segment-text"].score == pytest.approx(0.65) + assert result_by_segment_id["segment-parent-summary"].score == pytest.approx(0.4) + assert result_by_segment_id["segment-parent-summary"].summary == "summary for parent" + assert result_by_segment_id["segment-parent-summary"].child_chunks == [] + + def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch): + rollback = Mock() + 
monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback) + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error"))) + + documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")] + + with pytest.raises(RuntimeError, match="db error"): + RetrievalService.format_retrieval_documents(documents) + + rollback.assert_called_once() + + def test_retrieve_internal_returns_early_without_query_or_attachment(self, internal_dataset, internal_flask_app): + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=exceptions, + query=None, + attachment_id=None, + ) + + assert all_documents == [] + assert exceptions == [] + + def test_retrieve_internal_cancels_futures_when_future_has_exception(self, internal_dataset, internal_flask_app): + future_error = Mock() + future_error.exception.return_value = RuntimeError("future failed") + future_ok = Mock() + future_ok.exception.return_value = None + + with ( + patch("core.rag.datasource.retrieval_service.ThreadPoolExecutor") as mock_executor, + patch( + "core.rag.datasource.retrieval_service.concurrent.futures.as_completed", + return_value=[future_error, future_ok], + ), + ): + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = [future_error, future_ok] + mock_executor.return_value.__enter__.return_value = mock_executor_instance + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=[], + query="query", + attachment_id="file-1", + ) + + future_error.cancel.assert_called() + future_ok.cancel.assert_called() + + def test_retrieve_internal_raises_value_error_when_exceptions_exist( + self, 
monkeypatch, internal_dataset, internal_flask_app + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + with patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") as mock_keyword_search: + mock_keyword_search.side_effect = lambda *args, **kwargs: None + with pytest.raises(ValueError, match="keyword error"): + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.KEYWORD_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=["keyword error"], + query="query", + ) + + def test_retrieve_internal_hybrid_weighted_attachment_flow(self, monkeypatch, internal_dataset, internal_flask_app): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + text_doc = create_mock_document("text", "text-doc", 0.81) + image_doc = create_mock_document("image", "image-doc", 0.72) + fulltext_doc = create_mock_document("full", "full-doc", 0.65) + processed_doc = create_mock_document("processed", "processed-doc", 0.99) + + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") as mock_embedding_search, + patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search") as mock_fulltext, + patch("core.rag.datasource.retrieval_service.DataPostProcessor") as mock_processor_class, + ): + + def embedding_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + query_type=QueryType.TEXT_QUERY, + 
): + if query_type == QueryType.IMAGE_QUERY: + all_documents.append(image_doc) + else: + all_documents.append(text_doc) + + mock_embedding_search.side_effect = embedding_side_effect + + def fulltext_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + ): + all_documents.append(fulltext_doc) + + mock_fulltext.side_effect = fulltext_side_effect + processor_instance = Mock() + processor_instance.invoke.return_value = [processed_doc] + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.HYBRID_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=[], + query="query", + attachment_id="file-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + top_k=3, + ) + + assert len(all_documents) == 4 + assert any(doc.metadata["doc_id"] == "processed-doc" for doc in all_documents) + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_info_success(self, mock_sign): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + binding = SimpleNamespace(segment_id="segment-1", attachment_id="upload-1") + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = binding + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result == { + "attachment_info": { + "id": "upload-1", + "name": "file-name", + "extension": ".png", + "mime_type": "image/png", + 
"source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + mock_sign.assert_called_once_with("upload-1", "png") + + def test_get_segment_attachment_info_returns_none_when_binding_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = None + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_info_returns_none_when_upload_file_missing(self): + upload_query = Mock() + upload_query.where.return_value.first.return_value = None + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_infos_returns_empty_when_upload_files_missing(self): + upload_query = Mock() + upload_query.where.return_value.all.return_value = [] + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + def test_get_segment_attachment_infos_returns_empty_when_bindings_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert 
result == [] + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_infos_success(self, mock_sign): + upload_file_1 = SimpleNamespace( + id="upload-1", + name="file-1", + extension="png", + mime_type="image/png", + size=42, + ) + upload_file_2 = SimpleNamespace( + id="upload-2", + name="file-2", + extension="jpg", + mime_type="image/jpeg", + size=99, + ) + binding = SimpleNamespace(attachment_id="upload-1", segment_id="segment-1") + + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file_1, upload_file_2] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [binding] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1", "upload-2"], session) + + assert result == [ + { + "attachment_id": "upload-1", + "attachment_info": { + "id": "upload-1", + "name": "file-1", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + ] + mock_sign.assert_has_calls( + [ + call("upload-1", "png"), + call("upload-2", "jpg"), + ] + ) + assert mock_sign.call_count == 2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py new file mode 100644 index 0000000000..e063a49f22 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py @@ -0,0 +1,74 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module +from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory + + +def 
test_validate_distance_function_accepts_supported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + assert factory._validate_distance_function("cosine") == "cosine" + assert factory._validate_distance_function("euclidean") == "euclidean" + + +def test_validate_distance_function_rejects_unsupported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + with pytest.raises(ValueError, match="Invalid distance function"): + factory._validate_distance_function("dot_product") + + +def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "cosine") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 6) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" + + +def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch): + factory = 
AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", + index_struct_dict=None, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "euclidean") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 12) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + vector_cls.assert_called_once() + assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2" + assert dataset.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py new file mode 100644 index 0000000000..545565cdf4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py @@ -0,0 +1,133 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector 
import AnalyticdbVector, AnalyticdbVectorFactory +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig +from core.rag.models.document import Document + + +def test_init_prefers_openapi_when_api_config_is_provided(): + api_config = AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls: + vector = AnalyticdbVector("COLLECTION", api_config=api_config, sql_config=None) + + assert vector.analyticdb_vector == "openapi_runner" + openapi_cls.assert_called_once_with("COLLECTION", api_config) + + +def test_init_uses_sql_implementation_when_api_config_is_missing(): + sql_config = AnalyticdbVectorBySqlConfig( + host="localhost", + port=5432, + account="account", + account_password="password", + min_connection=1, + max_connection=2, + namespace="dify", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls: + vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config) + + assert vector.analyticdb_vector == "sql_runner" + sql_cls.assert_called_once_with("COLLECTION", sql_config) + + +def test_init_raises_when_both_configs_are_missing(): + with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"): + AnalyticdbVector("COLLECTION", api_config=None, sql_config=None) + + +def test_vector_methods_delegate_to_underlying_implementation(): + runner = MagicMock() + runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})] + runner.search_by_full_text.return_value = [Document(page_content="t", 
metadata={"doc_id": "2"})] + runner.text_exists.return_value = True + + vector = AnalyticdbVector.__new__(AnalyticdbVector) + vector.analyticdb_vector = runner + + texts = [Document(page_content="hello", metadata={"doc_id": "d1"})] + vector.create(texts=texts, embeddings=[[0.1, 0.2]]) + vector.add_texts(documents=texts, embeddings=[[0.1, 0.2]]) + assert vector.text_exists("d1") is True + vector.delete_by_ids(["d1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value + assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value + vector.delete() + + runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) + runner.delete_by_ids.assert_called_once_with(["d1"]) + runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") + runner.delete.assert_called_once() + + +def test_get_type_is_analyticdb(): + vector = AnalyticdbVector.__new__(AnalyticdbVector) + assert vector.get_type() == "analyticdb" + + +def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", None) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_ID", "ak") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_SECRET", "sk") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_REGION_ID", "cn-hz") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_INSTANCE_ID", "instance") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, 
"ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE_PASSWORD", "ns-password") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "auto_collection" + assert isinstance(args[1], AnalyticdbVectorOpenAPIConfig) + assert args[2] is None + assert dataset.index_struct is not None + + +def test_factory_builds_sql_config_when_host_is_present(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", "127.0.0.1") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PORT", 5432) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MIN_CONNECTION", 1) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MAX_CONNECTION", 3) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "existing" + assert args[1] is None + assert isinstance(args[2], AnalyticdbVectorBySqlConfig) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py 
b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py new file mode 100644 index 0000000000..45777774d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py @@ -0,0 +1,384 @@ +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from core.rag.models.document import Document + + +def _request_class(name: str): + class _Request: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + _Request.__name__ = name + return _Request + + +def _install_openapi_stubs(monkeypatch): + gpdb_package = types.ModuleType("alibabacloud_gpdb20160503") + gpdb_package.__path__ = [] + gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models") + for class_name in [ + "InitVectorDatabaseRequest", + "DescribeNamespaceRequest", + "CreateNamespaceRequest", + "DescribeCollectionRequest", + "CreateCollectionRequest", + "UpsertCollectionDataRequestRows", + "UpsertCollectionDataRequest", + "QueryCollectionDataRequest", + "DeleteCollectionDataRequest", + "DeleteCollectionRequest", + ]: + setattr(gpdb_models, class_name, _request_class(class_name)) + + class _Client: + def __init__(self, config): + self.config = config + + gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client") + gpdb_client.Client = _Client + gpdb_package.models = gpdb_models + + tea_openapi = types.ModuleType("alibabacloud_tea_openapi") + tea_openapi.__path__ = [] + tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models") + + class OpenApiConfig: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + 
tea_openapi_models.Config = OpenApiConfig + tea_openapi.models = tea_openapi_models + + tea_package = types.ModuleType("Tea") + tea_package.__path__ = [] + tea_exceptions = types.ModuleType("Tea.exceptions") + + class TeaError(Exception): + def __init__(self, status_code=None, **kwargs): + super().__init__("TeaException") + status_code = kwargs.get("statusCode", status_code) + self.statusCode = status_code + self.status_code = status_code + + tea_exceptions.TeaException = TeaError + tea_package.exceptions = tea_exceptions + + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models) + monkeypatch.setitem(sys.modules, "Tea", tea_package) + monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions) + + return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig) + + +def _config() -> AnalyticdbVectorOpenAPIConfig: + return AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("access_key_id", "", "ANALYTICDB_KEY_ID"), + ("access_key_secret", "", "ANALYTICDB_KEY_SECRET"), + ("region_id", "", "ANALYTICDB_REGION_ID"), + ("instance_id", "", "ANALYTICDB_INSTANCE_ID"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"), + ], +) +def test_openapi_config_validation(field, value, error_message): + values = 
_config().model_dump() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorOpenAPIConfig.model_validate(values) + + +def test_openapi_config_to_client_params(): + config = _config() + params = config.to_analyticdb_client_params() + + assert params["access_key_id"] == "ak" + assert params["access_key_secret"] == "sk" + assert params["region_id"] == "cn-hangzhou" + assert params["read_timeout"] == 60000 + + +def test_init_creates_openapi_client_and_runs_initialize(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + initialize_mock = MagicMock() + monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock) + + vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config()) + + assert vector._collection_name == "collection_1" + assert isinstance(vector._client_config, stubs.OpenApiConfig) + assert vector._client_config.user_agent == "dify" + assert vector._client_config.access_key_id == "ak" + assert vector._client.config is vector._client_config + initialize_mock.assert_called_once_with() + + +def test_initialize_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + vector._create_namespace_if_not_exists.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + 
monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + vector._create_namespace_if_not_exists.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_initialize_vector_database_calls_openapi_client(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._initialize_vector_database() + + request = vector._client.init_vector_database.call_args.args[0] + assert request.dbinstance_id == "instance-1" + assert request.region_id == "cn-hangzhou" + assert request.manager_account == "account" + assert request.manager_account_password == "password" + + +def test_create_namespace_creates_when_namespace_not_found(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404) + + vector._create_namespace_if_not_exists() + + vector._client.create_namespace.assert_called_once() + + +def test_create_namespace_raises_on_unexpected_api_error(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500) + + with 
pytest.raises(ValueError, match="failed to create namespace"): + vector._create_namespace_if_not_exists() + + +def test_create_namespace_noop_when_namespace_exists(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._create_namespace_if_not_exists() + + vector._client.describe_namespace.assert_called_once() + vector._client.create_namespace.assert_not_called() + + +def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.create_collection.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() 
+ + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.describe_collection.assert_not_called() + vector._client.create_collection.assert_not_called() + + +def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create collection collection_1"): + vector._create_collection_if_not_exists(embedding_dimension=512) + + +def test_openapi_add_delete_and_search_methods(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + documents = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + embeddings = [[0.1, 0.2], [0.2, 0.3]] + vector.add_texts(documents, embeddings) + + upsert_request = vector._client.upsert_collection_data.call_args.args[0] + assert upsert_request.collection == "collection_1" + assert len(upsert_request.rows) == 1 + + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()])) + ) + assert vector.text_exists("d1") is True + + vector.delete_by_ids(["d1", "d2"]) + request 
= vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')" + + vector.delete_by_metadata_field("document_id", "doc-1") + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'" + + match_high = SimpleNamespace( + score=0.9, + metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"}, + values=SimpleNamespace(value=[1.0, 2.0]), + ) + match_low = SimpleNamespace( + score=0.1, + metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"}, + values=SimpleNamespace(value=[3.0, 4.0]), + ) + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high])) + ) + + docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + assert len(docs_by_vector) == 1 + assert docs_by_vector[0].page_content == "high" + assert docs_by_vector[0].metadata["score"] == 0.9 + + docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2) + assert len(docs_by_text) == 1 + assert docs_by_text[0].page_content == "high" + + +def test_text_exists_returns_false_when_matches_empty(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[])) + ) + + assert vector.text_exists("missing-id") is False + + +def test_openapi_delete_success(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = 
MagicMock() + + vector.delete() + vector._client.delete_collection.assert_called_once() + + +def test_openapi_delete_propagates_errors(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.delete_collection.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.delete() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py new file mode 100644 index 0000000000..8f1206696b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py @@ -0,0 +1,427 @@ +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import psycopg2.errors +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import ( + AnalyticdbVectorBySql, + AnalyticdbVectorBySqlConfig, +) +from core.rag.models.document import Document + + +def _config_values() -> dict: + return { + "host": "localhost", + "port": 5432, + "account": "account", + "account_password": "password", + "min_connection": 1, + "max_connection": 2, + "namespace": "dify", + } + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("host", "", "ANALYTICDB_HOST"), + ("port", 0, "ANALYTICDB_PORT"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"), + ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"), + ], +) +def test_sql_config_required_fields(field, value, error_message): + values = _config_values() + values[field] = value + + with pytest.raises(ValueError, 
match=error_message): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_sql_config_rejects_min_connection_greater_than_max_connection(): + values = _config_values() + values["min_connection"] = 10 + values["max_connection"] = 2 + + with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_initialize_skips_when_cache_exists(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + sql_module.redis_client.set.assert_called_once() + + +def test_create_connection_pool_uses_psycopg2_pool(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + 
vector.databaseName = "knowledgebase" + + pool_instance = MagicMock() + monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool_instance)) + + pool = vector._create_connection_pool() + + assert pool is pool_instance + sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once() + + +def test_get_cursor_context_manager_handles_connection_lifecycle(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + cursor = MagicMock() + connection = MagicMock() + connection.cursor.return_value = cursor + pool = MagicMock() + pool.getconn.return_value = connection + vector.pool = pool + + with vector._get_cursor() as cur: + assert cur is cursor + + cursor.close.assert_called_once() + connection.commit.assert_called_once() + pool.putconn.assert_called_once_with(connection) + + +def test_add_texts_inserts_only_documents_with_metadata(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id") + monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock()) + + docs = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]]) + + execute_args = sql_module.psycopg2.extras.execute_batch.call_args.args + assert execute_args[0] is cursor + assert len(execute_args[2]) == 1 + + +def test_text_exists_returns_true_and_false_based_on_query_result(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.fetchone.return_value = ("row",) + assert 
vector.text_exists("d1") is True + + cursor.fetchone.return_value = None + assert vector.text_exists("d1") is False + + +def test_delete_by_ids_handles_empty_input_and_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_ids(["d1"]) + + +def test_delete_by_metadata_field_handles_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_metadata_field("document_id", "doc-1") + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_vector_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k) + + +def test_search_by_vector_returns_documents_above_threshold(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}), + ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, 
document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content 1" + assert docs[0].metadata["score"] == 0.8 + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_full_text_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=invalid_top_k) + + +def test_search_by_full_text_returns_documents(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + docs = vector.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + assert docs[0].page_content == "content 1" + + +def test_delete_drops_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete() + + cursor.execute.assert_called_once() + + +def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch): + config = AnalyticdbVectorBySqlConfig(**_config_values()) + created_pool = MagicMock() + + monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock()) + monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=created_pool)) + + vector = AnalyticdbVectorBySql("My_Collection", config) + + assert vector._collection_name == "my_collection" + assert vector.table_name == "dify.my_collection" + assert vector.databaseName 
== "knowledgebase" + assert vector.pool is created_pool + + +def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + bootstrap_cursor.execute.side_effect = RuntimeError("database already exists") + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + + def _execute(sql, *args, **kwargs): + if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql: + raise RuntimeError("already exists") + + worker_cursor.execute.side_effect = _execute + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + vector._initialize_vector_database() + + bootstrap_cursor.close.assert_called_once() + bootstrap_connection.close.assert_called_once() + vector._create_connection_pool.assert_called_once() + assert any( + "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0] + for call in worker_cursor.execute.call_args_list + ) + assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list) + + +def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = 
bootstrap_cursor + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable") + + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + with pytest.raises(RuntimeError, match="Failed to create zhparser extension"): + vector._initialize_vector_database() + + worker_connection.rollback.assert_called_once() + + +def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + vector._create_collection_if_not_exists(embedding_dimension=3) + + assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) + assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) + sql_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = 
AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + cursor.execute.side_effect = RuntimeError("permission denied") + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + with pytest.raises(RuntimeError, match="permission denied"): + vector._create_collection_if_not_exists(embedding_dimension=3) + + +def test_delete_methods_raise_when_error_is_not_missing_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.execute.side_effect = RuntimeError("unexpected delete failure") + with pytest.raises(RuntimeError, match="unexpected delete failure"): + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("unexpected metadata failure") + with pytest.raises(RuntimeError, match="unexpected metadata failure"): + vector.delete_by_metadata_field("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py new file mode 100644 index 0000000000..c46c3d5e4b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py @@ -0,0 +1,542 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def 
_build_fake_pymochow_modules(): + pymochow = types.ModuleType("pymochow") + pymochow.__path__ = [] + pymochow_auth = types.ModuleType("pymochow.auth") + pymochow_auth.__path__ = [] + pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials") + pymochow_configuration = types.ModuleType("pymochow.configuration") + pymochow_exception = types.ModuleType("pymochow.exception") + pymochow_model = types.ModuleType("pymochow.model") + pymochow_model.__path__ = [] + pymochow_model_database = types.ModuleType("pymochow.model.database") + pymochow_model_enum = types.ModuleType("pymochow.model.enum") + pymochow_model_schema = types.ModuleType("pymochow.model.schema") + pymochow_model_table = types.ModuleType("pymochow.model.table") + + class _SimpleObject: + def __init__(self, *args, **kwargs): + self.args = args + for key, value in kwargs.items(): + setattr(self, key, value) + + class ServerError(Exception): + def __init__(self, code): + super().__init__(f"server error {code}") + self.code = code + + class ServerErrCode: + TABLE_NOT_EXIST = 1001 + DB_ALREADY_EXIST = 1002 + + class IndexType: + __members__ = {"HNSW": "HNSW"} + + class MetricType: + __members__ = {"IP": "IP"} + + class IndexState: + NORMAL = "NORMAL" + + class TableState: + NORMAL = "NORMAL" + + class InvertedIndexAnalyzer: + DEFAULT_ANALYZER = "DEFAULT_ANALYZER" + + class InvertedIndexParseMode: + COARSE_MODE = "COARSE_MODE" + + class InvertedIndexFieldAttribute: + ANALYZED = "ANALYZED" + + class FieldType: + STRING = "STRING" + TEXT = "TEXT" + JSON = "JSON" + FLOAT_VECTOR = "FLOAT_VECTOR" + + pymochow.MochowClient = _SimpleObject + pymochow_credentials.BceCredentials = _SimpleObject + pymochow_configuration.Configuration = _SimpleObject + pymochow_exception.ServerError = ServerError + pymochow_model_database.Database = _SimpleObject + + pymochow_model_enum.FieldType = FieldType + pymochow_model_enum.IndexState = IndexState + pymochow_model_enum.IndexType = IndexType + 
pymochow_model_enum.MetricType = MetricType + pymochow_model_enum.ServerErrCode = ServerErrCode + pymochow_model_enum.TableState = TableState + + for cls_name in [ + "AutoBuildRowCountIncrement", + "Field", + "FilteringIndex", + "HNSWParams", + "InvertedIndex", + "InvertedIndexParams", + "Schema", + "VectorIndex", + ]: + setattr(pymochow_model_schema, cls_name, _SimpleObject) + pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer + pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute + pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode + + for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]: + setattr(pymochow_model_table, cls_name, _SimpleObject) + + pymochow.auth = pymochow_auth + pymochow.model = pymochow_model + pymochow_auth.bce_credentials = pymochow_credentials + pymochow_model.database = pymochow_model_database + pymochow_model.enum = pymochow_model_enum + pymochow_model.schema = pymochow_model_schema + pymochow_model.table = pymochow_model_table + + modules = { + "pymochow": pymochow, + "pymochow.auth": pymochow_auth, + "pymochow.auth.bce_credentials": pymochow_credentials, + "pymochow.configuration": pymochow_configuration, + "pymochow.exception": pymochow_exception, + "pymochow.model": pymochow_model, + "pymochow.model.database": pymochow_model_database, + "pymochow.model.enum": pymochow_model_enum, + "pymochow.model.schema": pymochow_model_schema, + "pymochow.model.table": pymochow_model_table, + } + return modules + + +@pytest.fixture +def baidu_module(monkeypatch): + for name, module in _build_fake_pymochow_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + import core.rag.datasource.vdb.baidu.baidu_vector as module + + return importlib.reload(module) + + +def test_baidu_config_validation(baidu_module): + values = { + "endpoint": "https://example.com", + "account": "account", + "api_key": "key", + "database": "database", + } + config 
= baidu_module.BaiduConfig.model_validate(values) + assert config.endpoint == "https://example.com" + + for key, error_message in [ + ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"), + ("account", "BAIDU_VECTOR_DB_ACCOUNT"), + ("api_key", "BAIDU_VECTOR_DB_API_KEY"), + ("database", "BAIDU_VECTOR_DB_DATABASE"), + ]: + invalid = dict(values) + invalid[key] = "" + with pytest.raises(ValueError, match=error_message): + baidu_module.BaiduConfig.model_validate(invalid) + + +def test_get_search_result_handles_metadata_and_threshold(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace( + rows=[ + {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9}, + {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4}, + {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95}, + ] + ) + + docs = vector._get_search_res(response, score_threshold=0.8) + + assert len(docs) == 2 + assert docs[0].page_content == "doc1" + assert docs[0].metadata["score"] == 0.9 + assert docs[1].page_content == "doc3" + + +def test_delete_by_ids_and_delete_by_metadata_field(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + vector._collection_name = "collection_1" + + vector.delete_by_ids([]) + table.delete.assert_not_called() + + vector.delete_by_ids(["id1", "id2"]) + table.delete.assert_called_once() + + table.delete.reset_mock() + vector.delete_by_metadata_field("source", 'abc"def') + delete_filter = table.delete.call_args.kwargs["filter"] + assert delete_filter == 'metadata["source"] = "abc\\"def"' + + +def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + + vector._db.drop_table.side_effect 
= baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + vector.delete() + + vector._db.drop_table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector.delete() + + +def test_init_database_uses_existing_or_creates_when_missing(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + + vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")] + vector._client.database.return_value = "existing_db" + assert vector._init_database() == "existing_db" + + vector._client.list_databases.return_value = [] + vector._client.database.return_value = "created_db" + vector._client.create_database.side_effect = None + assert vector._init_database() == "created_db" + + vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST) + assert vector._init_database() == "created_db" + + +def test_table_existed_checks_table_access(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._db.table.return_value = MagicMock() + + assert vector._table_existed() is True + + vector._db.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + assert vector._table_existed() is False + + vector._db.table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector._table_existed() + + +def test_search_methods_delegate_to_database_table(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._get_search_res = MagicMock(return_value=[Document(page_content="doc", metadata={"doc_id": "1"})]) + + table = MagicMock() + vector._db.table.return_value = table 
+ table.search.return_value = "vector_result" + table.bm25_search.return_value = "bm25_result" + + result1 = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + result2 = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + + assert result1 == vector._get_search_res.return_value + assert result2 == vector._get_search_res.return_value + assert vector._get_search_res.call_count == 2 + + +def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", 
return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection" + assert dataset.index_struct is not None + + +def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch): + init_client = MagicMock(return_value="client") + init_database = MagicMock(return_value="database") + monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client) + monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database) + + config = baidu_module.BaiduConfig( + endpoint="https://example.com", + account="account", + api_key="key", + database="db", + ) + vector = baidu_module.BaiduVector(collection_name="my_collection", config=config) + + assert vector.get_type() == baidu_module.VectorType.BAIDU + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection" + assert vector._client == "client" + assert vector._db == "database" + + vector._create_table = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="p1", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_table.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_rows(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2", "document_id": "doc-2"}), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.3, 0.4]]) + + assert table.upsert.call_count == 1 + inserted_rows = table.upsert.call_args.kwargs["rows"] + assert len(inserted_rows) == 2 + + +def 
test_add_texts_batches_more_than_batch_size(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"}) + for idx in range(1001) + ] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + vector.add_texts(docs, embeddings) + + assert table.upsert.call_count == 2 + assert len(table.upsert.call_args_list[0].kwargs["rows"]) == 1000 + assert len(table.upsert.call_args_list[1].kwargs["rows"]) == 1 + + +def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + table.query.return_value = SimpleNamespace(code=0) + assert vector.text_exists("id-1") is True + + table.query.return_value = SimpleNamespace(code=1) + assert vector.text_exists("id-1") is False + + table.query.return_value = None + assert vector.text_exists("id-1") is False + + +def test_get_search_result_handles_invalid_metadata_json(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace(rows=[{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}]) + + docs = vector._get_search_res(response, score_threshold=0.1) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.7 + assert "document_id" not in docs[0].metadata + + +def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch): + credentials = MagicMock(return_value="credentials") + configuration = MagicMock(return_value="configuration") + client_cls = MagicMock(return_value="client") + monkeypatch.setattr(baidu_module, "BceCredentials", credentials) + 
monkeypatch.setattr(baidu_module, "Configuration", configuration) + monkeypatch.setattr(baidu_module, "MochowClient", client_cls) + + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint") + + client = vector._init_client(config) + + assert client == "client" + credentials.assert_called_once_with("account", "key") + configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint") + client_cls.assert_called_once_with("configuration") + + +def test_init_database_raises_for_unknown_create_database_error(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + vector._client.list_databases.return_value = [] + vector._client.create_database.side_effect = baidu_module.ServerError(9999) + + with pytest.raises(baidu_module.ServerError): + vector._init_database() + + +def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + table = MagicMock() + table.state = baidu_module.TableState.NORMAL + vector._db.describe_table.return_value = table + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + 
monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock()) + + # Cached table skips all work. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Existing table also skips creation. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + vector._table_existed.return_value = True + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Create table when cache is empty and table does not exist. + vector._table_existed.return_value = False + vector._create_table(3) + vector._db.create_table.assert_called_once() + baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600) + table.rebuild_index.assert_called_once_with(vector.vector_index) + vector._wait_for_index_ready.assert_called_once_with(table, 3600) + + +def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + vector._client_config = SimpleNamespace( + index_type="INVALID", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + + with 
pytest.raises(ValueError, match="unsupported index_type"): + vector._create_table(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "INVALID" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_table(3) + + +def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + vector._db.describe_table.return_value = SimpleNamespace(state="CREATING") + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301])) + + with pytest.raises(TimeoutError, match="Table creation timeout"): + vector._create_table(3) + + +def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, 
"BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py new file mode 100644 index 0000000000..44427b7d87 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py @@ -0,0 +1,199 @@ +import importlib +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_chroma_modules(): + chroma = types.ModuleType("chromadb") + chroma.DEFAULT_TENANT = "default_tenant" + chroma.DEFAULT_DATABASE = "default_database" + + 
def test_chroma_config_to_params_builds_expected_payload(chroma_module):
    """to_chroma_params must carry every config field into the client kwargs."""
    cfg_kwargs = {
        "host": "localhost",
        "port": 8000,
        "tenant": "tenant-1",
        "database": "db-1",
        "auth_provider": "provider",
        "auth_credentials": "credentials",
    }

    params = chroma_module.ChromaConfig(**cfg_kwargs).to_chroma_params()

    # Connection fields are passed through verbatim.
    for key in ("host", "port", "tenant", "database"):
        assert params[key] == cfg_kwargs[key]
    assert params["ssl"] is False

    # Auth fields land on the embedded Settings object.
    settings = params["settings"]
    assert settings.chroma_client_auth_provider == "provider"
    assert settings.chroma_client_auth_credentials == "credentials"
chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create_collection("collection_1") + + vector._client.get_or_create_collection.assert_called_once_with("collection_1") + chroma_module.redis_client.set.assert_called_once() + + +def test_create_with_empty_texts_is_noop(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create([], []) + vector._client.get_or_create_collection.assert_not_called() + + +def test_create_with_texts_creates_collection_and_upserts(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + docs = [Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._client.get_or_create_collection.assert_called() + vector._client.collection.upsert.assert_called_once() + + +def test_delete_methods_and_text_exists(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + + vector.delete_by_ids([]) + vector._client.collection.delete.assert_not_called() + + vector.delete_by_ids(["id-1"]) + vector._client.collection.delete.assert_called_with(ids=["id-1"]) + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}}) + + vector._client.collection.get.return_value = {"ids": ["id-1"]} + assert vector.text_exists("id-1") is True + vector._client.collection.get.return_value = {} + assert vector.text_exists("id-2") is False + + vector.delete() + 
def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module):
    """Hits whose score (1 - distance) falls below the threshold are dropped."""
    chroma_vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    chroma_vector._client.collection.query.return_value = {
        "ids": [["id-1", "id-2"]],
        "documents": [["doc high", "doc low"]],
        "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]],
        "distances": [[0.1, 0.8]],
    }

    docs = chroma_vector.search_by_vector(
        [0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]
    )

    # distance 0.1 -> score 0.9 survives; distance 0.8 -> score 0.2 is filtered out.
    assert [doc.page_content for doc in docs] == ["doc high"]
    assert docs[0].metadata["score"] == 0.9
monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None) + + with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py new file mode 100644 index 0000000000..0ce5c04dd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py @@ -0,0 +1,927 @@ +import importlib +import queue +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickzetta_module(): + clickzetta = types.ModuleType("clickzetta") + + class _FakeCursor: + def __init__(self): + self.execute = MagicMock() + self.executemany = MagicMock() + self.fetchall = MagicMock(return_value=[]) + self.fetchone = MagicMock(return_value=(0,)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class 
_FakeConnection: + def __init__(self): + self.cursor_obj = _FakeCursor() + + def cursor(self): + return self.cursor_obj + + def close(self): + return None + + def connect(**_kwargs): + return _FakeConnection() + + clickzetta.connect = connect + return clickzetta + + +@pytest.fixture +def clickzetta_module(monkeypatch): + monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) + import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.ClickzettaConfig( + username="username", + password="password", + instance="instance", + service="service", + workspace="workspace", + vcluster="cluster", + schema_name="dify", + ) + + +@pytest.mark.parametrize( + ("field", "error_message"), + [ + ("username", "CLICKZETTA_USERNAME"), + ("password", "CLICKZETTA_PASSWORD"), + ("instance", "CLICKZETTA_INSTANCE"), + ("service", "CLICKZETTA_SERVICE"), + ("workspace", "CLICKZETTA_WORKSPACE"), + ("vcluster", "CLICKZETTA_VCLUSTER"), + ("schema_name", "CLICKZETTA_SCHEMA"), + ], +) +def test_clickzetta_config_validation(clickzetta_module, field, error_message): + values = _config(clickzetta_module).model_dump() + values[field] = "" + with pytest.raises(ValueError, match=error_message): + clickzetta_module.ClickzettaConfig.model_validate(values) + + +def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed = vector._parse_metadata('{"document_id":"doc-1"}', "row-1") + assert parsed["doc_id"] == "row-1" + assert parsed["document_id"] == "doc-1" + + parsed_double = vector._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2") + assert parsed_double["doc_id"] == "row-2" + assert parsed_double["document_id"] == "doc-2" + + parsed_fallback = vector._parse_metadata("not-json", "row-3") + assert parsed_fallback["doc_id"] == "row-3" + assert 
def test_safe_doc_id_and_vector_format_helpers(clickzetta_module):
    """_format_vector_simple joins floats with commas; _safe_doc_id sanitizes and truncates ids."""
    vec = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)

    assert vec._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3"

    # Allowed characters pass through; whitespace and punctuation are stripped.
    for raw, cleaned in (("abc-123_DEF", "abc-123_DEF"), ("ab c;\n", "abc")):
        assert vec._safe_doc_id(raw) == cleaned

    # Over-long ids are truncated to the 255-character column limit.
    assert len(vec._safe_doc_id("a" * 300)) == 255
MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchone.return_value = (1,) + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + assert vector.text_exists("doc-1") is True + + +def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + + vector.delete_by_ids([]) + vector._execute_write.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + vector.delete_by_ids(["doc-1"]) + vector._execute_write.assert_not_called() + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_not_called() + + +def test_search_short_circuit_behaviors(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + vector._config.enable_inverted_index = False + assert vector.search_by_full_text("query", top_k=2) == [] + + +def test_search_by_like_returns_documents_with_default_score(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')] + 
connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content" + assert docs[0].metadata["score"] == 0.5 + + +def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch): + factory = clickzetta_module.ClickzettaVectorFactory() + dataset = SimpleNamespace(id="dataset-1") + + monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance") + + with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "collection" + + +def 
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
    """get_instance returns one shared pool; the config key embeds identity fields.

    Fix: the singleton slot is reset via monkeypatch.setattr instead of a direct
    ``_instance = None`` assignment, so the previous instance is restored after
    the test and the pool with the mocked cleanup thread cannot leak into other
    tests in the session.
    """
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_instance", None)
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())

    pool_1 = clickzetta_module.ClickzettaConnectionPool.get_instance()
    pool_2 = clickzetta_module.ClickzettaConnectionPool.get_instance()
    key = pool_1._get_config_key(_config(clickzetta_module))

    assert pool_1 is pool_2
    # Key must be derived from the connection identity (user/instance/service/...).
    assert "username:instance:service:workspace:cluster:dify" in key
def test_connection_pool_configure_connection_swallows_errors(clickzetta_module):
    """_configure_connection must not propagate cursor-creation failures.

    Fix: the original built a bare ``pytest.MonkeyPatch()`` and called
    ``.undo()`` at the end — if any statement raised, the patch was never
    undone and leaked into subsequent tests. ``MonkeyPatch.context()``
    guarantees the patch is reverted even on failure.
    """
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
        pool = clickzetta_module.ClickzettaConnectionPool()
        connection = MagicMock()
        connection.cursor.side_effect = RuntimeError("cannot configure")

        # Should swallow the RuntimeError internally rather than raise.
        pool._configure_connection(connection)
MagicMock(return_value=False) + monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0)) + pool.get_connection(config) + expired_connection.close.assert_called_once() + + random_connection = MagicMock() + pool._is_connection_valid = MagicMock(return_value=True) + pool.return_connection(config, random_connection) + assert len(pool._pools[key]) == 1 + + pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)] + pool._connection_timeout = 10 + pool._cleanup_expired_connections() + assert len(pool._pools[key]) == 1 + + unknown_pool = MagicMock() + pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), unknown_pool) + unknown_pool.close.assert_called_once() + + pool.shutdown() + assert pool._shutdown is True + + +def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True)) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + self.started = False + + def start(self): + self.started = True + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + + assert pool._cleanup_thread.started is True + pool._cleanup_expired_connections.assert_called_once() + + +def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch): + pool = MagicMock() + pool.get_connection.return_value = "conn" + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool)) + monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock()) + + vector = 
clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module)) + assert vector._table_name == "my_collection" + + assert vector._get_connection() == "conn" + vector._return_connection("conn") + pool.return_connection.assert_called_with(vector._config, "conn") + + with vector.get_connection_context() as conn: + assert conn == "conn" + assert pool.return_connection.call_count >= 2 + + assert vector.get_type() == "clickzetta" + assert vector._ensure_connection() == "conn" + + +def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch): + class _Thread: + def __init__(self, target, daemon): + self.target = target + self.daemon = daemon + self.started = 0 + + def start(self): + self.started += 1 + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_thread = None + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._init_write_queue() + clickzetta_module.ClickzettaVector._init_write_queue() + assert clickzetta_module.ClickzettaVector._write_thread.started == 1 + + result_queue_ok = queue.Queue() + result_queue_fail = queue.Queue() + clickzetta_module.ClickzettaVector._write_queue = queue.Queue() + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, result_queue_ok)) + clickzetta_module.ClickzettaVector._write_queue.put( + (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, result_queue_fail) + ) + clickzetta_module.ClickzettaVector._write_queue.put(None) + clickzetta_module.ClickzettaVector._write_worker() + + assert result_queue_ok.get() == (True, 2) + failed = result_queue_fail.get() + assert failed[0] is False + assert isinstance(failed[1], RuntimeError) + + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + 
clickzetta_module.ClickzettaVector._write_queue = None + with pytest.raises(RuntimeError, match="Write queue not initialized"): + vector._execute_write(lambda: None) + + class _ImmediateSuccessQueue: + def put(self, task): + func, args, kwargs, result_q = task + result_q.put((True, func(*args, **kwargs))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue() + assert vector._execute_write(lambda x: x * 2, 3) == 6 + + class _ImmediateFailQueue: + def put(self, task): + _, _, _, result_q = task + result_q.put((False, ValueError("write failed"))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue() + with pytest.raises(ValueError, match="write failed"): + vector._execute_write(lambda: None) + + +def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_exists(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_exists + assert vector._table_exists() is True + + vector._execute_write = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="content", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._execute_write.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_table_and_indexes_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._create_vector_index = MagicMock() + vector._create_inverted_index = MagicMock() + + vector._table_exists = MagicMock(return_value=True) + 
vector._create_table_and_indexes([[0.1, 0.2]]) + vector._create_vector_index.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._create_table_and_indexes([[0.1, 0.2, 0.3]]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_called_once() + + vector._config.enable_inverted_index = False + vector._create_vector_index.reset_mock() + vector._create_inverted_index.reset_mock() + vector._create_table_and_indexes([]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_not_called() + + +def test_create_vector_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show index failed"), None] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("already exists")] + cursor.fetchall.return_value = [] + vector._create_vector_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("unexpected")] + cursor.fetchall.return_value = [] + with pytest.raises(RuntimeError, match="unexpected"): + vector._create_vector_index(cursor) + + +def test_create_inverted_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + 
vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show failed"), None] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [ + None, + RuntimeError("already has index"), + None, + ] + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("other create failure")] + cursor.fetchall.return_value = [] + vector._create_inverted_index(cursor) + + +def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.batch_size = 2 + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id)) + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2"}), + Document(page_content="doc-3", metadata={"doc_id": "id-3"}), + ] + vectors = [[0.1], [0.2], [0.3]] + + vector.add_texts([], []) + vector._execute_write.assert_not_called() + + added_ids = vector.add_texts(docs, vectors) + assert added_ids == ["id-1", "id-2", "id-3"] + assert vector._execute_write.call_count == 2 + assert vector._execute_write.call_args_list[0].args == ( + vector._insert_batch, + docs[:2], + vectors[:2], + ["id-1", "id-2"], + 0, + 2, + 2, + ) + assert vector._execute_write.call_args_list[1].args == ( + vector._insert_batch, + docs[2:], + vectors[2:], + ["id-3"], + 2, + 2, + 2, + 
) + + vector._insert_batch([], [], [], 0, 2, 1) + vector._insert_batch(docs[:1], vectors, ["id-1"], 0, 2, 1) + + bad_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}}) + good_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [bad_doc, good_doc], + [[0.1, 0.2], [0.3, 0.4]], + ["id-bad", "id-good"], + 0, + 2, + 1, + ) + + @contextmanager + def _ctx_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.executemany.side_effect = RuntimeError("insert failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_error + with pytest.raises(RuntimeError, match="insert failed"): + vector._insert_batch([good_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1) + + monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id") + vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector) + assert vector._safe_doc_id("") == "generated-id" + assert vector._safe_doc_id("!!!") == "generated-id" + + +def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._table_exists = MagicMock(return_value=True) + + vector.delete_by_ids(["id-1", "id-2"]) + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl + + vector._execute_write.reset_mock() + vector.delete_by_metadata_field("document_id", 
"doc-1") + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl + + vector._safe_doc_id = MagicMock(side_effect=lambda x: x) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._delete_by_ids_impl(["id-1", "id-2"]) + vector._delete_by_metadata_field_impl("document_id", "doc-1") + + +def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.vector_distance_function = "cosine_distance" + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + cosine_docs = vector.search_by_vector( + [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"} + ) + assert cosine_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._config.vector_distance_function = "l2_distance" + l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2) + + +def test_search_by_full_text_success_and_fallback(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = 
_config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx_success(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'), + ("seg-2", "content-2", "invalid-json"), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_success + docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1}) + assert len(docs) == 2 + assert docs[0].metadata["score"] == 1.0 + assert docs[1].metadata["doc_id"] == "seg-2" + + @contextmanager + def _ctx_failure(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("full text failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_failure + vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})]) + fallback_docs = vector.search_by_full_text("query", top_k=1) + assert fallback_docs == vector._search_by_like.return_value + + +def test_search_by_like_missing_table_and_delete_table(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=False) + assert vector._search_by_like("query", top_k=1) == [] + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector.delete() + + +def 
test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._pools = {} + pool._pool_locks = {} + pool._max_pool_size = 1 + pool._connection_timeout = 10 + pool._lock = clickzetta_module.threading.Lock() + pool._shutdown = False + + config = _config(clickzetta_module) + key = pool._get_config_key(config) + pool._pools[key] = [(MagicMock(), 1.0)] + pool._pool_locks[key] = clickzetta_module.threading.Lock() + pool._is_connection_valid = MagicMock(return_value=False) + + conn = MagicMock() + pool.return_connection(config, conn) + conn.close.assert_called_once() + + pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)] + pool._cleanup_expired_connections() + pool.shutdown() + assert pool._shutdown is True + + +def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + + def _cleanup_then_fail(): + pool._shutdown = True + raise RuntimeError("cleanup failed") + + pool._cleanup_expired_connections = MagicMock(side_effect=_cleanup_then_fail) + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + + def start(self): + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + pool._cleanup_expired_connections.assert_called_once() + + +def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed_non_dict = vector._parse_metadata("[1,2,3]", "row-1") + assert parsed_non_dict["doc_id"] == "row-1" + assert parsed_non_dict["document_id"] == "row-1" + + parsed_none = 
vector._parse_metadata(None, "row-2") + assert parsed_none["doc_id"] == "row-2" + assert parsed_none["document_id"] == "row-2" + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_worker() + + class _BadQueue: + def get(self, timeout): + clickzetta_module.ClickzettaVector._shutdown = True + raise RuntimeError("queue failed") + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = _BadQueue() + clickzetta_module.ClickzettaVector._write_worker() + + +def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + cursor.execute.side_effect = [ + None, + RuntimeError("already has index with the same type cannot create inverted index"), + None, + ] + + vector._create_inverted_index(cursor) + + vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value)) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor_obj = MagicMock() + cursor_obj.__enter__.return_value = cursor_obj + cursor_obj.__exit__.return_value = None + connection.cursor.return_value = cursor_obj + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [SimpleNamespace(page_content="content", metadata="not-a-dict")], + [[0.1, 0.2]], + ["doc-1"], + 0, + 1, + 1, + ) + + +def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.enable_inverted_index = True + vector._table_name = "table_1" + + vector._table_exists = 
MagicMock(return_value=False) + assert vector.search_by_full_text("query") == [] + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", "[1,2,3]"), + ("seg-2", "content-2", None), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector.search_by_full_text("query") + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[1].metadata["doc_id"] == "seg-2" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py new file mode 100644 index 0000000000..9fea187615 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py @@ -0,0 +1,364 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_couchbase_modules(): + couchbase = types.ModuleType("couchbase") + couchbase_auth = types.ModuleType("couchbase.auth") + couchbase_cluster = types.ModuleType("couchbase.cluster") + couchbase_management = types.ModuleType("couchbase.management") + couchbase_management_search = types.ModuleType("couchbase.management.search") + couchbase_options = types.ModuleType("couchbase.options") + couchbase_vector = types.ModuleType("couchbase.vector_search") + couchbase_search = types.ModuleType("couchbase.search") + + class PasswordAuthenticator: + def __init__(self, user, password): + self.user = user + self.password = password + + class ClusterOptions: + def __init__(self, auth): + self.auth = auth + + class SearchOptions: + def 
__init__(self, **kwargs): + self.kwargs = kwargs + + class VectorQuery: + def __init__(self, field, vector, top_k): + self.field = field + self.vector = vector + self.top_k = top_k + + class VectorSearch: + @staticmethod + def from_vector_query(vector_query): + return {"vector_query": vector_query} + + class QueryStringQuery: + def __init__(self, query): + self.query = query + + class SearchRequest: + @staticmethod + def create(payload): + return {"payload": payload} + + class SearchIndex: + def __init__(self, name, params, source_name): + self.name = name + self.params = params + self.source_name = source_name + + class _QueryResult: + def __init__(self, rows=None): + self._rows = rows or [] + + def execute(self): + return self + + def __iter__(self): + return iter(self._rows) + + class _SearchIter: + def __init__(self, rows=None): + self._rows = rows or [] + + def rows(self): + return self._rows + + class _Collection: + def __init__(self): + self.upsert = MagicMock(return_value=True) + + class _SearchIndexManager: + def __init__(self): + self.upsert_index = MagicMock() + + class _Scope: + def __init__(self): + self._collection = _Collection() + self._search_index_manager = _SearchIndexManager() + self.search = MagicMock(return_value=_SearchIter()) + + def collection(self, _name): + return self._collection + + def search_indexes(self): + return self._search_index_manager + + class _CollectionManager: + def __init__(self): + self.create_collection = MagicMock() + self.drop_collection = MagicMock() + self.get_all_scopes = MagicMock(return_value=[]) + + class _Bucket: + def __init__(self): + self._scope = _Scope() + self._collections = _CollectionManager() + + def scope(self, _scope_name): + return self._scope + + def collections(self): + return self._collections + + class Cluster: + def __init__(self, connection_string, options): + self.connection_string = connection_string + self.options = options + self._bucket = _Bucket() + self.wait_until_ready = MagicMock() + 
self.query = MagicMock(return_value=_QueryResult()) + + def bucket(self, _name): + return self._bucket + + couchbase_auth.PasswordAuthenticator = PasswordAuthenticator + couchbase_cluster.Cluster = Cluster + couchbase_management_search.SearchIndex = SearchIndex + couchbase_options.ClusterOptions = ClusterOptions + couchbase_options.SearchOptions = SearchOptions + couchbase_vector.VectorQuery = VectorQuery + couchbase_vector.VectorSearch = VectorSearch + couchbase_search.QueryStringQuery = QueryStringQuery + couchbase_search.SearchRequest = SearchRequest + + couchbase.search = couchbase_search + couchbase.management = couchbase_management + + return { + "couchbase": couchbase, + "couchbase.auth": couchbase_auth, + "couchbase.cluster": couchbase_cluster, + "couchbase.management": couchbase_management, + "couchbase.management.search": couchbase_management_search, + "couchbase.options": couchbase_options, + "couchbase.vector_search": couchbase_vector, + "couchbase.search": couchbase_search, + } + + +@pytest.fixture +def couchbase_module(monkeypatch): + for name, module in _build_fake_couchbase_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.couchbase.couchbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.CouchbaseConfig( + connection_string="couchbase://localhost", + user="user", + password="pass", + bucket_name="bucket", + scope_name="scope", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("connection_string", "", "CONNECTION_STRING is required"), + ("user", "", "COUCHBASE_USER is required"), + ("password", "", "COUCHBASE_PASSWORD is required"), + ("bucket_name", "", "COUCHBASE_PASSWORD is required"), + ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"), + ], +) +def test_couchbase_config_validation(couchbase_module, field, value, message): + values = _config(couchbase_module).model_dump() + values[field] = value + with 
pytest.raises(ValidationError, match=message): + couchbase_module.CouchbaseConfig.model_validate(values) + + +def test_init_sets_cluster_handles(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + assert vector._bucket_name == "bucket" + assert vector._scope_name == "scope" + vector._cluster.wait_until_ready.assert_called_once() + + +def test_create_and_create_collection_branches(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector) + vector._collection_name = "collection_1" + vector._client_config = _config(couchbase_module) + vector._scope_name = "scope" + vector._bucket_name = "bucket" + vector._bucket = MagicMock() + vector._scope = MagicMock() + vector._collection_exists = MagicMock(return_value=False) + vector.add_texts = MagicMock() + + monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c") + vector._create_collection = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock()) + + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(vector_length=2, uuid="uuid-1") + vector._bucket.collections().create_collection.assert_not_called() + + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._collection_exists = MagicMock(return_value=True) + vector._create_collection(vector_length=2, 
uuid="uuid-2") + vector._bucket.collections().create_collection.assert_not_called() + + vector._collection_exists = MagicMock(return_value=False) + vector._create_collection(vector_length=3, uuid="uuid-3") + + vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1") + vector._scope.search_indexes().upsert_index.assert_called_once() + search_index = vector._scope.search_indexes().upsert_index.call_args.args[0] + assert search_index.name == "collection_1_search" + assert ( + search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"] + == 3 + ) + couchbase_module.redis_client.set.assert_called_once() + + +def test_collection_exists_get_type_and_add_texts(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="collection_1")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is True + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="other")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is False + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert vector._scope.collection("collection_1").upsert.call_count == 2 + assert vector.get_type() == couchbase_module.VectorType.COUCHBASE + + +def test_query_delete_helpers(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([{"count": 2}])) + assert 
vector.text_exists("id-1") is True + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([])) + assert vector.text_exists("id-2") is False + + query_result = MagicMock() + query_result.execute.return_value = None + vector._cluster.query.return_value = query_result + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_document_id("id-1") + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._cluster.query.call_count >= 3 + + vector._cluster.query.side_effect = RuntimeError("delete failed") + vector.delete_by_ids(["id-3"]) + + +def test_search_methods_and_format_metadata(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + row_1 = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9) + row_2 = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_1, row_2]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["document_id"] == "d-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector._scope.search.side_effect = RuntimeError("search error") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_vector([0.1], top_k=1) + + vector._scope.search.side_effect = None + row_3 = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_3]) + docs = vector.search_by_full_text("hello", top_k=1) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == "x" + + vector._scope.search.side_effect = RuntimeError("full text failed") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_full_text("hello", top_k=1) + + assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == 
{"a": 1, "plain": 2} + + +def test_delete_collection_and_factory(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + scopes = [ + SimpleNamespace(collections=[SimpleNamespace(name="other")]), + SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]), + ] + vector._bucket.collections().get_all_scopes.return_value = scopes + + vector.delete() + vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1") + + factory = couchbase_module.CouchbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + couchbase_module, + "current_app", + SimpleNamespace( + config={ + "COUCHBASE_CONNECTION_STRING": "couchbase://localhost", + "COUCHBASE_USER": "user", + "COUCHBASE_PASSWORD": "pass", + "COUCHBASE_BUCKET_NAME": "bucket", + "COUCHBASE_SCOPE_NAME": "scope", + } + ), + ) + + with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py 
b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py new file mode 100644 index 0000000000..edd29a4649 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py @@ -0,0 +1,121 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0"}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_ja_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module + + importlib.reload(base_module) + return importlib.reload(ja_module) + + +def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = 
elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_not_called() + + +def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2, 0.3]], [{}]) + + vector._client.indices.create.assert_called_once() + kwargs = vector._client.indices.create.call_args.kwargs + assert kwargs["index"] == "test" + assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3 + elasticsearch_ja_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_ja_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_called_once() + + +def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch): + factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": 
"EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + elasticsearch_ja_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_HOST": "localhost", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + + with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py new file mode 100644 index 0000000000..9ecf0caa24 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py @@ -0,0 +1,405 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = 
MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}}) + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), + delete=MagicMock(), + exists=MagicMock(return_value=False), + create=MagicMock(), + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module + + return importlib.reload(module) + + +def _regular_config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "username": "elastic", + "password": "secret", + "verify_certs": False, + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +def _cloud_config(module, **overrides): + values = { + "use_cloud": True, + "cloud_url": "https://cloud.example:9243", + "api_key": "api-key", + "verify_certs": True, + "ca_certs": "/tmp/ca.pem", + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("values", "message"), + [ + ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"), + ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"), + ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"), + ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"), + ({"host": "h", "port": 9200, "username": None, "password": "p"}, 
"USERNAME is required"), + ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"), + ], +) +def test_elasticsearch_config_validation(elasticsearch_module, values, message): + with pytest.raises(ValidationError, match=message): + elasticsearch_module.ElasticSearchConfig.model_validate(values) + + +def test_init_client_cloud_configuration(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + result = vector._init_client(_cloud_config(elasticsearch_module)) + + assert result is client + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://cloud.example:9243"] + assert kwargs["api_key"] == "api-key" + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + +def test_init_client_regular_https_and_http_fallback(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client( + _regular_config( + elasticsearch_module, + host="https://es.example", + port=9443, + verify_certs=True, + ca_certs="/tmp/ca.pem", + ) + ) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://es.example:9443"] + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200)) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["http://es.internal:9200"] + assert "verify_certs" not in kwargs + + +def 
test_init_client_connection_failures(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + + client = MagicMock() + client.ping.return_value = False + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client): + with pytest.raises(ConnectionError, match="Failed to connect"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object( + elasticsearch_module, + "Elasticsearch", + side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"), + ): + with pytest.raises(ConnectionError, match="Vector database connection error"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")): + with pytest.raises(ConnectionError, match="initialization failed"): + vector._init_client(_regular_config(elasticsearch_module)) + + +def test_init_get_version_and_check_version(elasticsearch_module): + with ( + patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client, + patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version, + patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version, + ): + vector = elasticsearch_module.ElasticSearchVector( + "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"] + ) + + init_client.assert_called_once() + get_version.assert_called_once() + check_version.assert_called_once() + assert vector._attributes == ["doc_id"] + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._client = MagicMock() + vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}} + assert vector._get_version() == "8.13.2" + + vector._version = "7.17.0" + with pytest.raises(ValueError, match="greater than 8.0.0"): + 
vector._check_version() + + vector._version = "8.0.0" + vector._check_version() + + +def test_crud_methods_and_get_type(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock()) + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection_1") + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._client.delete.call_count == 2 + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "d1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "d2") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1") + assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH + + +def test_search_by_vector_and_full_text(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.8, + "_source": { + 
elasticsearch_module.Field.CONTENT_KEY: "doc-a", + elasticsearch_module.Field.VECTOR: [0.1], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-b", + elasticsearch_module.Field.VECTOR: [0.2], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + knn = vector._client.search.call_args.kwargs["knn"] + assert knn["k"] == 2 + assert knn["num_candidates"] == 3 + assert "filter" in knn + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "text-hit", + elasticsearch_module.Field.VECTOR: [0.3], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + query = vector._client.search.call_args.kwargs["query"] + assert "bool" in query + + +def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = 
[Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + mappings = vector._client.indices.create.call_args.kwargs["mappings"] + assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2 + elasticsearch_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + elasticsearch_module.redis_client.set.assert_called_once() + + +def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch): + factory = elasticsearch_module.ElasticSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + monkeypatch.setattr( + 
elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": False, + "ELASTICSEARCH_HOST": "es-host", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + "ELASTICSEARCH_VERIFY_CERTS": False, + } + ), + ) + + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + assert result_1 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION" + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic", + "ELASTICSEARCH_API_KEY": "api-key", + "ELASTICSEARCH_VERIFY_CERTS": True, + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + assert result_2 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is True + assert cfg.cloud_url == "https://cloud.elastic" + assert dataset_without_index.index_struct is not None + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": None, + "ELASTICSEARCH_HOST": "fallback-host", + "ELASTICSEARCH_PORT": 9201, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert cfg.host == "fallback-host" diff 
--git a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py new file mode 100644 index 0000000000..5d9e744ded --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py @@ -0,0 +1,371 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_hologres_modules(): + holo_module = types.ModuleType("holo_search_sdk") + holo_types_module = types.ModuleType("holo_search_sdk.types") + + holo_types_module.BaseQuantizationType = str + holo_types_module.DistanceType = str + holo_types_module.TokenizerType = str + + def _connect(**kwargs): + client = MagicMock() + client.kwargs = kwargs + client.connect = MagicMock() + client.check_table_exist = MagicMock(return_value=False) + client.open_table = MagicMock(return_value=MagicMock()) + client.execute = MagicMock(return_value=[]) + client.drop_table = MagicMock() + return client + + holo_module.connect = MagicMock(side_effect=_connect) + + return { + "holo_search_sdk": holo_module, + "holo_search_sdk.types": holo_types_module, + } + + +@pytest.fixture +def hologres_module(monkeypatch): + for name, module in _build_fake_hologres_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.hologres.hologres_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.HologresVectorConfig( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema_name="public", + tokenizer="jieba", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + 
("host", "", "config HOLOGRES_HOST is required"), + ("database", "", "config HOLOGRES_DATABASE is required"), + ("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"), + ], +) +def test_hologres_config_validation(hologres_module, field, value, message): + values = _valid_config(hologres_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + hologres_module.HologresVectorConfig.model_validate(values) + + +def test_init_client_and_get_type(hologres_module): + vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module)) + + hologres_module.holo.connect.assert_called_once_with( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema="public", + ) + vector._client.connect.assert_called_once() + assert vector.table_name == "embedding_collection_one" + assert vector.get_type() == hologres_module.VectorType.HOLOGRES + + +def test_create_delegates_collection_creation_and_upsert(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result is None + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_returns_empty_for_empty_documents(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + assert vector.add_texts([], []) == [] + vector._client.open_table.assert_not_called() + + +def test_add_texts_batches_and_serializes_metadata(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + table = 
vector._client.open_table.return_value + documents = [ + Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"}) + for i in range(100) + ] + documents.append(SimpleNamespace(page_content="doc-100", metadata=None)) + embeddings = [[float(i)] for i in range(len(documents))] + + ids = vector.add_texts(documents, embeddings) + + assert ids[:2] == ["id-0", "id-1"] + assert ids[-1] == "" + assert len(ids) == 101 + assert vector._client.open_table.call_count == 2 + assert table.upsert_multi.call_count == 2 + first_call = table.upsert_multi.call_args_list[0].kwargs + second_call = table.upsert_multi.call_args_list[1].kwargs + assert first_call["index_column"] == "id" + assert first_call["column_names"] == ["id", "text", "meta", "embedding"] + assert first_call["update_columns"] == ["text", "meta", "embedding"] + assert len(first_call["values"]) == 100 + assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"} + assert second_call["values"][0][0] == "" + assert second_call["values"][0][2] == "{}" + + +def test_text_exists_handles_missing_and_present_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + vector._client.execute.return_value = [(1,)] + + assert vector.text_exists("seg-1") is False + assert vector.text_exists("seg-1") is True + vector._client.execute.assert_called_once() + + +def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +def test_delete_by_ids_branches(hologres_module): + vector = 
hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + vector.delete_by_ids([]) + vector._client.check_table_exist.assert_not_called() + + vector._client.check_table_exist.return_value = False + vector.delete_by_ids(["id-1"]) + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_ids(["id-1", "id-2"]) + vector._client.execute.assert_called_once() + + +def test_delete_by_metadata_field_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_called_once() + + +def test_search_by_vector_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_vector([0.1, 0.2]) == [] + + +def test_search_by_vector_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + query = MagicMock() + table.search_vector.return_value = query + query.select.return_value = query + query.limit.return_value = query + query.where.return_value = query + query.fetchall.return_value = [ + (0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'), + (0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["doc-1"], + ) + + assert len(docs) == 1 + assert 
docs[0].page_content == "doc-1" + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.8) + table.search_vector.assert_called_once() + query.where.assert_called_once() + + +def test_search_by_full_text_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_full_text("query") == [] + + +def test_search_by_full_text_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + search_query = MagicMock() + table.search_text.return_value = search_query + search_query.limit.return_value = search_query + search_query.where.return_value = search_query + search_query.fetchall.return_value = [ + ("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95), + ("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"]) + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.95) + assert docs[1].metadata["score"] == pytest.approx(0.7) + table.search_text.assert_called_once() + search_query.where.assert_called_once() + + +def test_delete_handles_existing_and_missing_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + + vector.delete() + vector._client.drop_table.assert_not_called() + + vector.delete() + vector._client.drop_table.assert_called_once_with(vector.table_name) + + +def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch): + lock = MagicMock() + 
lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection(3) + + vector._client.check_table_exist.assert_not_called() + hologres_module.redis_client.set.assert_not_called() + + +def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, False, True] + table = vector._client.open_table.return_value + + vector._create_collection(3) + + vector._client.execute.assert_called_once() + table.set_vector_index.assert_called_once_with( + column="embedding", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + use_reorder=True, + ) + table.create_text_index.assert_called_once_with( + index_name="ft_idx_collection_one", + column="text", + tokenizer="jieba", + ) + hologres_module.redis_client.set.assert_called_once() + + +def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", 
MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False] + [False] * 15 + + with pytest.raises(RuntimeError, match="was not ready after 30s"): + vector._create_collection(3) + + hologres_module.redis_client.set.assert_not_called() + + +def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch): + factory = hologres_module.HologresVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64) + monkeypatch.setattr(hologres_module.dify_config, 
"HOLOGRES_EF_CONSTRUCTION", 400) + + with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection" + generated_config = vector_cls.call_args_list[1].kwargs["config"] + assert generated_config.host == "127.0.0.1" + assert generated_config.database == "dify" + assert generated_config.access_key_id == "ak" + assert json.loads(dataset_without_index.index_struct) == { + "type": hologres_module.VectorType.HOLOGRES, + "vector_store": {"class_prefix": "generated_collection"}, + } diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py new file mode 100644 index 0000000000..9d23dfcf63 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py @@ -0,0 +1,243 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + 
elasticsearch.Elasticsearch = Elasticsearch + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def huawei_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="user", password="pass") + + +def test_create_ssl_context(huawei_module): + ctx = huawei_module.create_ssl_context() + assert ctx.check_hostname is False + assert ctx.verify_mode == huawei_module.ssl.CERT_NONE + + +def test_huawei_config_validation_and_params(huawei_module): + with pytest.raises(ValidationError, match="HOSTS is required"): + huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""}) + + config = _config(huawei_module) + params = config.to_elasticsearch_params() + assert params["hosts"] == ["http://localhost:9200"] + assert params["basic_auth"] == ("user", "pass") + + config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None) + params = config.to_elasticsearch_params() + assert "basic_auth" not in params + + +def test_init_get_type_and_add_texts(huawei_module): + vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module)) + + assert vector._collection_name == "collection" + assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_crud_methods(huawei_module): + vector = 
huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once_with(index="collection", id="id-1") + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection") + + +def test_search_by_vector_and_full_text(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + }, + { + "_score": 0.1, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-b", + huawei_module.Field.VECTOR: [0.2], + huawei_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + query_body = vector._client.search.call_args.kwargs["body"] + assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2 + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + huawei_module.Field.CONTENT_KEY: "text-hit", + huawei_module.Field.VECTOR: [0.3], + huawei_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = 
vector.search_by_full_text("hello", top_k=3) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + + +def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch): + class FakeDocument: + def __init__(self, page_content, vector, metadata): + self.page_content = page_content + self.vector = vector + self.metadata = None + + monkeypatch.setattr(huawei_module, "Document", FakeDocument) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + } + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5) + + assert docs == [] + + +def test_create_and_create_collection_paths(huawei_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock()) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], 
[{}]) + vector._client.indices.create.assert_called_once() + + kwargs = vector._client.indices.create.call_args.kwargs + mappings = kwargs["mappings"] + assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2 + assert kwargs["settings"] == {"index.vector": True} + huawei_module.redis_client.set.assert_called_once() + + +def test_huawei_factory_branches(huawei_module, monkeypatch): + factory = huawei_module.HuaweiCloudVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass") + + with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py new file mode 100644 index 0000000000..63338ca809 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py @@ -0,0 +1,412 @@ +import importlib +import sys +import 
types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_iris_module(): + iris = types.ModuleType("iris") + + def connect(**_kwargs): + conn = MagicMock() + conn.cursor.return_value = MagicMock() + return conn + + iris.connect = MagicMock(side_effect=connect) + return iris + + +@pytest.fixture +def iris_module(monkeypatch): + monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) + + import core.rag.datasource.vdb.iris.iris_vector as module + + reloaded = importlib.reload(module) + reloaded._pool_instance = None + return reloaded + + +def _config(module, **overrides): + values = { + "IRIS_HOST": "localhost", + "IRIS_SUPER_SERVER_PORT": 1972, + "IRIS_USER": "user", + "IRIS_PASSWORD": "pass", + "IRIS_DATABASE": "db", + "IRIS_SCHEMA": "schema", + "IRIS_CONNECTION_URL": "url", + "IRIS_MIN_CONNECTION": 1, + "IRIS_MAX_CONNECTION": 2, + "IRIS_TEXT_INDEX": True, + "IRIS_TEXT_INDEX_LANGUAGE": "en", + } + values.update(overrides) + return module.IrisVectorConfig.model_validate(values) + + +def test_get_iris_pool_singleton(iris_module): + iris_module._pool_instance = None + cfg = _config(iris_module) + + with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls: + pool_1 = iris_module.get_iris_pool(cfg) + pool_2 = iris_module.get_iris_pool(cfg) + + assert pool_1 == "pool" + assert pool_2 == "pool" + pool_cls.assert_called_once_with(cfg) + + +@pytest.fixture +def pool_with_min_max(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn: + pool = iris_module.IrisConnectionPool(cfg) + yield pool, create_conn + + +def test_pool_initialization_respects_min_max(pool_with_min_max): + pool, create_conn = pool_with_min_max + assert len(pool._pool) 
== 2 + assert create_conn.call_count == 2 + + +@pytest.fixture +def pool_for_get_connection(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_get_connection_returns_existing_and_increments(pool_for_get_connection): + pool = pool_for_get_connection + conn = MagicMock() + pool._pool = [conn] + pool._in_use = 0 + assert pool.get_connection() is conn + assert pool._in_use == 1 + + +def test_get_connection_creates_new_when_empty(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = 0 + pool._create_connection = MagicMock(return_value="new-conn") + assert pool.get_connection() == "new-conn" + + +def test_get_connection_raises_when_exhausted(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = pool._max_size + with pytest.raises(RuntimeError, match="exhausted"): + pool.get_connection() + + +@pytest.fixture +def pool_for_return_connection(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_return_connection_adds_healthy(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool.return_connection(conn) + assert pool._pool[-1] is conn + assert pool._in_use == 0 + + +def test_return_connection_replaces_bad(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + bad_conn = MagicMock() + bad_cursor = MagicMock() + bad_cursor.execute.side_effect = OSError("bad") + bad_conn.cursor.return_value = bad_cursor + replacement = MagicMock() + pool._create_connection = MagicMock(return_value=replacement) + pool.return_connection(bad_conn) + bad_conn.close.assert_called_once() + assert pool._pool[-1] is 
replacement + assert pool._in_use == 0 + + +def test_return_connection_ignores_none(pool_for_return_connection): + pool = pool_for_return_connection + before = len(pool._pool) + pool.return_connection(None) + assert len(pool._pool) == before + + +@pytest.fixture +def pool_for_schema_and_close(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool._pool = [conn] + return pool, conn, cursor + + +def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = {"cached_schema"} + pool.ensure_schema_exists("cached_schema") + cursor.execute.assert_not_called() + + +def test_ensure_schema_exists_creates_new(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (0,) + pool.ensure_schema_exists("new_schema") + assert "new_schema" in pool._schemas_initialized + assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list) + conn.commit.assert_called_once() + + +def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (1,) + pool.ensure_schema_exists("existing_schema") + conn.commit.assert_not_called() + + +def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.execute.side_effect = RuntimeError("schema failure") + with pytest.raises(RuntimeError, match="schema failure"): + pool.ensure_schema_exists("broken_schema") + conn.rollback.assert_called() + + +def test_close_all_closes_and_resets(iris_module): + cfg = 
_config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + conn_2 = MagicMock() + conn_2.close.side_effect = OSError("close fail") + pool._pool = [conn, conn_2] + pool._schemas_initialized = {"x"} + pool.close_all() + assert pool._pool == [] + assert pool._in_use == 0 + assert pool._schemas_initialized == set() + + +def test_iris_vector_init_get_cursor_and_create(iris_module): + pool = MagicMock() + pool.get_connection.return_value = MagicMock() + + with patch.object(iris_module, "get_iris_pool", return_value=pool): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + assert vector.table_name == "EMBEDDING_COLLECTION" + assert vector.schema == "schema" + assert vector.get_type() == iris_module.VectorType.IRIS + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + + with vector._get_cursor() as got_cursor: + assert got_cursor is cursor + conn.commit.assert_called_once() + vector.pool.return_connection.assert_called_with(conn) + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + with pytest.raises(RuntimeError, match="boom"): + with vector._get_cursor(): + raise RuntimeError("boom") + conn.rollback.assert_called_once() + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["id-1"]) + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"] + vector._create_collection.assert_called_once_with(2) + + +def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): 
+ yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id") + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "generated-id"] + assert cursor.execute.call_count == 2 + + cursor.fetchone.return_value = (1,) + assert vector.text_exists("id-1") is True + cursor.fetchone.return_value = None + assert vector.text_exists("id-2") is False + + vector._get_cursor = MagicMock(side_effect=RuntimeError("db down")) + assert vector.text_exists("id-3") is False + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + before = cursor.execute.call_count + vector.delete_by_ids(["id-1", "id-2"]) + assert cursor.execute.call_count == before + 1 + + vector.delete_by_metadata_field("document_id", "doc-1") + assert "meta LIKE" in cursor.execute.call_args.args[0] + + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.9), + ("id-2", "text-2", '{"document_id":"d-2"}', 0.2), + ("id-x",), + ] + docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + +def test_iris_vector_full_text_search_paths(iris_module, monkeypatch): + cfg = _config(iris_module, IRIS_TEXT_INDEX=True) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", cfg) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + cursor.execute.side_effect = None + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.7), + ("id-2", "text-2", "{}", None), + ] + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.7) + assert 
docs[1].metadata["score"] == pytest.approx(0.0) + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("rank failed"), None] + cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)] + docs = vector.search_by_full_text("query", top_k=1) + assert len(docs) == 1 + assert cursor.execute.call_count == 2 + + cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector_like = iris_module.IrisVector("collection", cfg_like) + vector_like._get_cursor = _cursor_ctx + + fake_libs = types.ModuleType("libs") + fake_helper = types.ModuleType("libs.helper") + fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%") + monkeypatch.setitem(sys.modules, "libs", fake_libs) + monkeypatch.setitem(sys.modules, "libs.helper", fake_helper) + + cursor.reset_mock() + cursor.execute.side_effect = None + cursor.fetchall.return_value = [] + assert vector_like.search_by_full_text("100%", top_k=1) == [] + + +def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete() + assert "DROP TABLE" in cursor.execute.call_args.args[0] + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(iris_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_called_once() + + cursor.reset_mock() + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None)) + 
vector.pool.ensure_schema_exists = MagicMock() + vector._create_collection(3) + assert cursor.execute.call_count == 3 + iris_module.redis_client.set.assert_called_once() + + cursor.reset_mock() + vector.config.IRIS_TEXT_INDEX = False + vector._create_collection(3) + assert cursor.execute.call_count == 2 + + factory = iris_module.IrisVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972) + monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user") + monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass") + monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema") + monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url") + monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1) + monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en") + + with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert 
vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py new file mode 100644 index 0000000000..34357d5907 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py @@ -0,0 +1,394 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearch_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = SimpleNamespace( + refresh=MagicMock(), + exists=MagicMock(return_value=False), + delete=MagicMock(), + create=MagicMock(), + ) + self.bulk = MagicMock(return_value={"errors": False, "items": []}) + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.delete_by_query = MagicMock() + self.get = MagicMock(return_value={"_id": "id"}) + self.exists = MagicMock(return_value=True) + + opensearch_helpers.BulkIndexError = BulkIndexError + opensearch_helpers.bulk = MagicMock() + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.helpers = opensearch_helpers + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearch_helpers, + } + + +@pytest.fixture +def lindorm_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.lindorm.lindorm_vector as module + + return 
importlib.reload(module) + + +def _config(module): + return module.LindormVectorStoreConfig( + hosts="http://localhost:9200", + username="user", + password="pass", + using_ugc=False, + request_timeout=3.0, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("hosts", None, "config URL is required"), + ("username", None, "config USERNAME is required"), + ("password", None, "config PASSWORD is required"), + ], +) +def test_lindorm_config_validation(lindorm_module, field, value, message): + values = _config(lindorm_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + lindorm_module.LindormVectorStoreConfig.model_validate(values) + + +def test_to_opensearch_params_and_init(lindorm_module): + cfg = _config(lindorm_module) + params = cfg.to_opensearch_params() + + assert params["hosts"] == "http://localhost:9200" + assert params["http_auth"] == ("user", "pass") + + vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False) + assert vector._collection_name == "collection" + assert vector.get_type() == lindorm_module.VectorType.LINDORM + + with pytest.raises(ValueError, match="routing_value"): + lindorm_module.LindormVectorStore("c", cfg, using_ugc=True) + + vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE") + assert vector_ugc._routing == "route" + + +def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = lindorm_module.LindormVectorStore( + "collection", 
_config(lindorm_module), using_ugc=True, routing_value="route" + ) + monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock()) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + + vector.add_texts(docs, embeddings, batch_size=2, timeout=9) + + assert vector._client.bulk.call_count == 2 + actions = vector._client.bulk.call_args_list[0].args[0] + assert actions[0]["index"]["routing"] == "route" + assert actions[1][lindorm_module.ROUTING_FIELD] == "route" + vector.refresh() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_add_texts_error_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]} + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + vector._client.bulk.side_effect = RuntimeError("bulk failed") + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + +def test_metadata_lookup_and_delete_by_metadata(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + query = vector._client.search.call_args.kwargs["body"] + must_conditions = query["query"]["bool"]["must"] + assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions) + + vector.delete_by_ids = MagicMock() + 
vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"]) + + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-2") + vector.delete_by_ids.assert_not_called() + + +def test_delete_by_ids_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + vector.delete_by_ids([]) + vector._client.indices.exists.assert_not_called() + + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["id-1"]) + + vector._client.indices.exists.return_value = True + vector._client.exists.side_effect = [True, False] + lindorm_module.helpers.bulk.reset_mock() + vector.delete_by_ids(["id-1", "id-2"]) + lindorm_module.helpers.bulk.assert_called_once() + actions = lindorm_module.helpers.bulk.call_args.args[1] + assert len(actions) == 1 + assert actions[0]["routing"] == "route" + + lindorm_module.helpers.bulk.reset_mock() + lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError( + errors=[ + {"delete": {"status": 404, "_id": "id-404"}}, + {"delete": {"status": 500, "_id": "id-500"}}, + ] + ) + vector._client.exists.side_effect = [True] + vector.delete_by_ids(["id-1"]) + + +def test_delete_and_text_exists(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.delete() + vector._client.delete_by_query.assert_called_once() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.indices.exists.return_value = True + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60}) + + 
vector._client.indices.delete.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete() + vector._client.indices.delete.assert_not_called() + + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("missing") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validation_and_success(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + with pytest.raises(ValueError, match="should be a list"): + vector.search_by_vector("bad") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, "bad"]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-b", + lindorm_module.Field.VECTOR: [0.2], + lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + call_kwargs = vector._client.search.call_args.kwargs + query = call_kwargs["body"] + assert "ext" in query + assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"] + assert call_kwargs["params"]["routing"] == "route" + + vector._client.search.side_effect = RuntimeError("search failed") + with pytest.raises(RuntimeError, match="search failed"): + vector.search_by_vector([0.1]) + + +def test_search_by_full_text_success_and_error(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + 
vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1"}, + } + } + ] + } + } + + docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] + + vector._client.search.side_effect = RuntimeError("full text failed") + with pytest.raises(RuntimeError, match="full text failed"): + vector.search_by_full_text("hello") + + +def test_create_collection_paths(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + + with pytest.raises(ValueError, match="cannot be empty"): + vector.create_collection([]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"}) + vector._client.indices.create.assert_called_once() + body = vector._client.indices.create.call_args.kwargs["body"] + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf" + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine" + + vector._client.indices.create.reset_mock() + 
vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + +def test_lindorm_factory_branches(lindorm_module, monkeypatch): + factory = lindorm_module.LindormVectorStoreFactory() + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2") + monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={}) + embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3]) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None) + with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"): + factory.init_vector(dataset, attributes=[], embeddings=embeddings) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + + dataset_existing_plain = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False}, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings) + assert result == "vector" + assert store_cls.call_args.args[0] == "existing" + + dataset_existing_ugc = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={ + "vector_store": {"class_prefix": "ROUTING"}, + "using_ugc": True, + "dimension": 1536, + "index_type": "hnsw", + 
"distance_type": "l2", + }, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "ROUTING" + + dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={}) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "auto_collection" + assert dataset_new.index_struct is not None + + dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={}) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "auto_collection" + assert store_cls.call_args.kwargs["routing_value"] is None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py new file mode 100644 index 0000000000..55e7b9112e --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py @@ -0,0 +1,252 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_mo_vector_modules(): + mo_vector = types.ModuleType("mo_vector") + 
mo_vector.__path__ = [] + mo_vector_client = types.ModuleType("mo_vector.client") + + class MoVectorClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_full_text_index = MagicMock() + self.insert = MagicMock() + self.get = MagicMock(return_value=[]) + self.delete = MagicMock() + self.query_by_metadata = MagicMock(return_value=[]) + self.query = MagicMock(return_value=[]) + self.full_text_query = MagicMock(return_value=[]) + + mo_vector_client.MoVectorClient = MoVectorClient + mo_vector.client = mo_vector_client + return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client} + + +@pytest.fixture +def matrixone_module(monkeypatch): + for name, module in _build_fake_mo_vector_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.matrixone.matrixone_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.MatrixoneConfig( + host="localhost", + port=6001, + user="dump", + password="111", + database="dify", + metric="l2", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config host is required"), + ("port", 0, "config port is required"), + ("user", "", "config user is required"), + ("password", "", "config password is required"), + ("database", "", "config database is required"), + ], +) +def test_matrixone_config_validation(matrixone_module, field, value, message): + values = _valid_config(matrixone_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + matrixone_module.MatrixoneConfig.model_validate(values) + + +def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", 
MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client.kwargs["table_name"] == "collection_1" + client.create_full_text_index.assert_called_once() + matrixone_module.redis_client.set.assert_called_once() + + +def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + client.create_full_text_index.assert_not_called() + matrixone_module.redis_client.set.assert_not_called() + + +def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = None + fake_client = MagicMock() + fake_client.get.return_value = [{"id": "seg-1"}] + vector._get_client = MagicMock(return_value=fake_client) + + exists = vector.text_exists("seg-1") + + assert exists is True + vector._get_client.assert_called_once_with(None, False) + + +def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.full_text_query.return_value = [ + SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": 
"2"}, distance=0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert vector.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}} + + +def test_get_type_and_create_delegate_to_add_texts(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + fake_client = MagicMock() + vector._get_client = MagicMock(return_value=fake_client) + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "matrixone" + assert result == ["seg-1"] + vector._get_client.assert_called_once_with(2, True) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + failing_client = MagicMock() + failing_client.create_full_text_index.side_effect = RuntimeError("boom") + monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client)) + + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client is failing_client + matrixone_module.redis_client.set.assert_not_called() + + +def test_add_texts_generates_ids_and_inserts(matrixone_module, 
monkeypatch): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + # For current prod code, only docs with metadata get ids, so only two ids + assert ids == ["doc-a", "generated-uuid"] + vector.client.insert.assert_called_once() + insert_kwargs = vector.client.insert.call_args.kwargs + # All lists passed to insert should be the same length + texts = insert_kwargs["texts"] + embeddings = insert_kwargs["embeddings"] + metadatas = insert_kwargs["metadatas"] + ids_insert = insert_kwargs["ids"] + assert len(texts) == len(embeddings) == len(metadatas) == len(docs) + # ids may be shorter than docs for current prod code, but should match number of docs with metadata + assert ids_insert == ["doc-a", "generated-uuid"] + + +def test_delete_and_metadata_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")] + + vector.delete_by_ids([]) + vector.client.delete.assert_not_called() + + vector.delete_by_ids(["seg-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + vector.delete() + + assert ids == ["seg-1", "seg-2"] + assert vector.client.delete.call_count == 3 + + +def test_search_by_vector_builds_documents(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query.return_value = [ + 
SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}), + ] + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"]) + + assert len(docs) == 2 + assert docs[0].page_content == "doc-a" + assert docs[1].metadata["doc_id"] == "2" + assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}} + + +def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch): + factory = matrixone_module.MatrixoneVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001) + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2") + + with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None 
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index fb2ddfe162..2ac2c40d38 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -1,18 +1,414 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + import pytest from pydantic import ValidationError -from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig +from core.rag.models.document import Document -def test_default_value(): +def _build_fake_pymilvus_modules(): + pymilvus = types.ModuleType("pymilvus") + pymilvus.__path__ = [] + pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client") + pymilvus_orm = types.ModuleType("pymilvus.orm") + pymilvus_orm.__path__ = [] + pymilvus_orm_types = types.ModuleType("pymilvus.orm.types") + + class MilvusError(Exception): + pass + + class MilvusClient: + def __init__(self, **kwargs): + self.init_kwargs = kwargs + self.has_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock( + return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]} + ) + self.get_server_version = MagicMock(return_value="2.5.0") + self.insert = MagicMock(return_value=[1]) + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.drop_collection = MagicMock() + self.search = MagicMock(return_value=[[]]) + self.create_collection = MagicMock() + + class IndexParams: + def __init__(self): + self.indexes = [] + + def add_index(self, **kwargs): + self.indexes.append(kwargs) + + class DataType: + JSON = "JSON" + VARCHAR = "VARCHAR" + INT64 = "INT64" + SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR" + FLOAT_VECTOR = "FLOAT_VECTOR" + + class FieldSchema: + def __init__(self, name, dtype, **kwargs): + self.name = name + self.dtype = dtype + self.kwargs = 
kwargs + + class CollectionSchema: + def __init__(self, fields): + self.fields = fields + self.functions = [] + + def add_function(self, func): + self.functions.append(func) + + class FunctionType: + BM25 = "BM25" + + class Function: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def infer_dtype_bydata(_value): + return DataType.FLOAT_VECTOR + + pymilvus.MilvusException = MilvusError + pymilvus.MilvusClient = MilvusClient + pymilvus.IndexParams = IndexParams + pymilvus.CollectionSchema = CollectionSchema + pymilvus.DataType = DataType + pymilvus.FieldSchema = FieldSchema + pymilvus.Function = Function + pymilvus.FunctionType = FunctionType + pymilvus_milvus_client.IndexParams = IndexParams + pymilvus_orm.types = pymilvus_orm_types + pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata + + # Attach submodules for dotted imports + pymilvus.milvus_client = pymilvus_milvus_client + pymilvus.orm = pymilvus_orm + + return { + "pymilvus": pymilvus, + "pymilvus.milvus_client": pymilvus_milvus_client, + "pymilvus.orm": pymilvus_orm, + "pymilvus.orm.types": pymilvus_orm_types, + } + + +@pytest.fixture +def milvus_module(monkeypatch): + for name, module in _build_fake_pymilvus_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.milvus.milvus_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "uri": "http://localhost:19530", + "user": "root", + "password": "Milvus", + "database": "default", + "enable_hybrid_search": False, + "analyzer_params": None, + } + values.update(overrides) + return module.MilvusConfig.model_validate(values) + + +def test_config_validation_and_defaults(milvus_module): valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: - MilvusConfig.model_validate(config) + 
milvus_module.MilvusConfig.model_validate(config) assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" - config = MilvusConfig.model_validate(valid_config) + config = milvus_module.MilvusConfig.model_validate(valid_config) assert config.database == "default" + + token_config = milvus_module.MilvusConfig.model_validate( + {"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"} + ) + assert token_config.token == "token-value" + + +def test_config_to_milvus_params(milvus_module): + config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + + params = config.to_milvus_params() + + assert params["uri"] == "http://localhost:19530" + assert params["db_name"] == "default" + assert params["analyzer_params"] == '{"tokenizer":"standard"}' + + +def test_init_client_supports_token_and_user_password(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + token_client = vector._init_client( + milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"}) + ) + assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"} + + user_client = vector._init_client(_config(milvus_module)) + assert user_client.init_kwargs["uri"] == "http://localhost:19530" + assert user_client.init_kwargs["user"] == "root" + assert user_client.init_kwargs["password"] == "Milvus" + + +def test_init_loads_fields_when_collection_exists(milvus_module): + client = milvus_module.MilvusClient(uri="http://localhost:19530") + client.has_collection.return_value = True + client.describe_collection.return_value = { + "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}] + } + + with patch.object(milvus_module.MilvusVector, "_init_client", return_value=client): + with patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False): + vector = 
milvus_module.MilvusVector("collection_1", _config(milvus_module)) + + assert "id" not in vector._fields + assert "content" in vector._fields + + +def test_load_collection_fields_from_argument_and_remote(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + vector._collection_name = "collection_1" + vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]} + + vector._load_collection_fields(["id", "metadata"]) + assert vector._fields == ["metadata"] + + vector._load_collection_fields() + assert vector._fields == ["content"] + + +def test_check_hybrid_search_support_branches(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + + vector._client_config = SimpleNamespace(enable_hybrid_search=False) + assert vector._check_hybrid_search_support() is False + + vector._client_config = SimpleNamespace(enable_hybrid_search=True) + vector._client.get_server_version.return_value = "Zilliz Cloud 2.4" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.5.1" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.4.9" + assert vector._check_hybrid_search_support() is False + + vector._client.get_server_version.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_get_type_and_create_delegate(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [SimpleNamespace(page_content="hello", metadata=None)] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "milvus" + vector.create_collection.assert_called_once() + create_args = vector.create_collection.call_args.args + assert create_args[0] == [[0.1, 0.2]] + 
assert create_args[1] == [{}] + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_and_raises_milvus_exception(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.insert.side_effect = [["id-1"], ["id-2"]] + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + ids = vector.add_texts(docs, embeddings) + assert ids == ["id-1", "id-2"] + assert vector._client.insert.call_count == 2 + + vector._client.insert.side_effect = milvus_module.MilvusException("insert failed") + with pytest.raises(milvus_module.MilvusException): + vector.add_texts([Document(page_content="x", metadata={})], [[0.1]]) + + +def test_get_ids_and_delete_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.query.return_value = [{"id": 1}, {"id": 2}] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2] + vector._client.query.return_value = [] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector._client.has_collection.return_value = True + vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102]) + + vector._client.delete.reset_mock() + vector._client.query.return_value = [{"id": 11}, {"id": 12}] + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12]) + + vector._client.has_collection.return_value = True + vector.delete() + vector._client.drop_collection.assert_called_once_with("collection_1", None) + + +def 
test_text_exists_and_field_exists(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._fields = ["content", "metadata"] + vector._client = MagicMock() + vector._client.has_collection.return_value = False + assert vector.text_exists("doc-1") is False + + vector._client.has_collection.return_value = True + vector._client.query.return_value = [{"id": 1}] + assert vector.text_exists("doc-1") is True + vector._client.query.return_value = [] + assert vector.text_exists("doc-1") is False + assert vector.field_exists("content") is True + assert vector.field_exists("unknown") is False + + +def test_process_search_results_and_search_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._fields = ["content", "metadata", "sparse_vector"] + + processed = vector._process_search_results( + [ + [ + {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9}, + {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2}, + ] + ], + [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY], + score_threshold=0.5, + ) + assert len(processed) == 1 + assert processed[0].metadata["score"] == 0.9 + + vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1) + assert docs == ["doc"] + assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]' + + vector._hybrid_search_enabled = False + assert vector.search_by_full_text("query") == [] + + vector._hybrid_search_enabled = True + vector._fields = [] + assert vector.search_by_full_text("query") == [] + + vector._fields = 
[milvus_module.Field.SPARSE_VECTOR] + vector._process_search_results = MagicMock(return_value=["full-text-doc"]) + full_text_docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2) + assert full_text_docs == ["full-text-doc"] + assert "document_id" in vector._client.search.call_args.kwargs["filter"] + + +def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client_config = _config(milvus_module) + vector._hybrid_search_enabled = False + vector._client = MagicMock() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.has_collection.return_value = True + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + milvus_module.redis_client.set.assert_called() + + +def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = 
milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client = MagicMock() + vector._client.has_collection.return_value = False + vector._load_collection_fields = MagicMock() + + vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + vector._hybrid_search_enabled = True + vector.create_collection( + embeddings=[[0.1, 0.2]], + metadatas=[{"doc_id": "1"}], + index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}}, + ) + + call_kwargs = vector._client.create_collection.call_args.kwargs + schema = call_kwargs["schema"] + index_params_obj = call_kwargs["index_params"] + field_names = [f.name for f in schema.fields] + + assert milvus_module.Field.SPARSE_VECTOR in field_names + assert len(schema.functions) == 1 + assert len(index_params_obj.indexes) == 2 + assert call_kwargs["consistency_level"] == "Session" + + +def test_factory_initializes_milvus_vector(milvus_module, monkeypatch): + factory = milvus_module.MilvusVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(milvus_module.dify_config, 
"MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}') + + with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py new file mode 100644 index 0000000000..a75ba82238 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py @@ -0,0 +1,230 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickhouse_connect_module(): + clickhouse_connect = types.ModuleType("clickhouse_connect") + + class QueryResult: + def __init__(self, rows=None, named_rows=None): + self.row_count = len(rows or []) + self.result_rows = rows or [] + self._named_rows = named_rows or [] + + def named_results(self): + return self._named_rows + + class Client: + def __init__(self): + self.command = MagicMock() + self.query = MagicMock(return_value=QueryResult()) + + client = Client() + + def get_client(**_kwargs): + return client + + clickhouse_connect.get_client = get_client + clickhouse_connect.QueryResult = QueryResult + clickhouse_connect._fake_client = client + return clickhouse_connect + + +@pytest.fixture +def myscale_module(monkeypatch): + fake_module = _build_fake_clickhouse_connect_module() + monkeypatch.setitem(sys.modules, 
"clickhouse_connect", fake_module) + + import core.rag.datasource.vdb.myscale.myscale_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="", + ) + + +def test_escape_str_replaces_backslash_and_quote(myscale_module): + escaped = myscale_module.MyScaleVector.escape_str(r"text\with'special") + assert escaped == "text with special" + + +def test_search_raises_for_invalid_top_k(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0) + + +def test_search_builds_where_clause_for_cosine_threshold(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = myscale_module.get_client().query.return_value.__class__( + named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}] + ) + + docs = vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2) + + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "WHERE dist < 0.8" in sql + + +def test_delete_by_ids_short_circuits_on_empty_list(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete_by_ids([]) + vector._client.command.assert_not_called() + + +def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch): + factory = myscale_module.MyScaleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", 
index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123) + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "") + + with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +def test_init_and_get_type_set_expected_defaults(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + + assert vector.get_type() == "myscale" + assert vector._vec_order == myscale_module.SortOrder.ASC + vector._client.command.assert_called_with("SET allow_experimental_object_type=1") + + +def test_create_calls_create_collection_and_add_texts(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + 
vector.add_texts.assert_called_once() + + +def test_create_collection_builds_expected_sql(myscale_module): + config = myscale_module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="tokenizer=unicode", + ) + vector = myscale_module.MyScaleVector("collection_1", config) + vector._client.command.reset_mock() + + vector._create_collection(3) + + assert vector._client.command.call_count == 2 + sql = vector._client.command.call_args_list[1].args[0] + assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in sql + assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in sql + assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql + + +def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="text-2", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="text-3", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + sql = vector._client.command.call_args.args[0] + assert "INSERT INTO dify.collection_1" in sql + assert "te xt 1" in sql + + +def test_text_exists_and_metadata_operations(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)]) + + assert vector.text_exists("id-1") is True + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._client.command.call_count >= 2 + + +def test_search_delegation_methods(myscale_module): + 
# (continuation of the MyScale search-delegation test whose `def` is above this chunk)
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._search = MagicMock(return_value=["result"])

    result_vector = vector.search_by_vector([0.1, 0.2], top_k=2)
    result_text = vector.search_by_full_text("hello", top_k=2)

    # Both public search entry points funnel into the shared _search helper.
    assert result_vector == ["result"]
    assert result_text == ["result"]
    assert vector._search.call_count == 2


def test_search_with_document_filter_and_exception(myscale_module):
    """_search injects a document_id filter into the SQL and returns [] on query errors."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.query.return_value = SimpleNamespace(
        named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}]
    )

    docs = vector._search(
        "distance(vector, [0.1])",
        myscale_module.SortOrder.ASC,
        top_k=2,
        document_ids_filter=["doc-1", "doc-2"],
    )
    assert len(docs) == 1
    sql = vector._client.query.call_args.args[0]
    assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql

    # A client-side failure is swallowed and surfaces as an empty result set.
    vector._client.query.side_effect = RuntimeError("boom")
    assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == []


def test_delete_drops_table(myscale_module):
    """delete() issues a DROP TABLE for the collection's backing table."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.command.reset_mock()

    vector.delete()

    vector._client.command.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1")
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py
new file mode 100644
index 0000000000..27d8198ec0
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py
@@ -0,0 +1,553 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError
from sqlalchemy.exc import SQLAlchemyError

from core.rag.models.document import Document


def _build_fake_pyobvector_module():
    """Build a stub `pyobvector` module so the OceanBase vector module imports without the real client."""
    pyobvector = types.ModuleType("pyobvector")

    class VECTOR:
        def __init__(self, dim):
            self.dim = dim

    # Distance helpers only need stable, distinguishable return values for assertions.
    def l2_distance(*_args, **_kwargs):
        return "l2"

    def cosine_distance(*_args, **_kwargs):
        return "cosine"

    def inner_product(*_args, **_kwargs):
        return "inner_product"

    class ObVecClient:
        # Every client operation is a MagicMock so tests can script/inspect calls freely.
        def __init__(self, **_kwargs):
            self.metadata_obj = SimpleNamespace(tables={})
            self.engine = MagicMock()
            self.check_table_exists = MagicMock(return_value=False)
            self.perform_raw_text_sql = MagicMock()
            self.prepare_index_params = MagicMock()
            self.create_table_with_index_params = MagicMock()
            self.refresh_metadata = MagicMock()
            self.insert = MagicMock()
            self.refresh_index = MagicMock()
            self.get = MagicMock()
            self.delete = MagicMock()
            self.set_ob_hnsw_ef_search = MagicMock()
            self.ann_search = MagicMock(return_value=[])
            self.drop_table_if_exist = MagicMock()

    pyobvector.VECTOR = VECTOR
    pyobvector.ObVecClient = ObVecClient
    pyobvector.l2_distance = l2_distance
    pyobvector.cosine_distance = cosine_distance
    pyobvector.inner_product = inner_product
    return pyobvector


@pytest.fixture
def oceanbase_module(monkeypatch):
    # Install the stub before (re)importing so module-level `import pyobvector` binds to it.
    monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())

    import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module

    return importlib.reload(module)


def _config(module):
    """Return a minimal valid OceanBaseVectorConfig for tests."""
    return module.OceanBaseVectorConfig(
        host="127.0.0.1",
        port=2881,
        user="root",
        password="secret",
        database="test",
        enable_hybrid_search=True,
        batch_size=10,
    )


@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OCEANBASE_VECTOR_HOST is required"),
        ("port", 0, "config OCEANBASE_VECTOR_PORT is required"),
        ("user", "", "config OCEANBASE_VECTOR_USER is required"),
        ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"),
    ],
)
def test_oceanbase_config_validation(oceanbase_module, field, value, message):
    values = _config(oceanbase_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        oceanbase_module.OceanBaseVectorConfig.model_validate(values)


def test_init_rejects_invalid_collection_name(oceanbase_module):
    with pytest.raises(ValueError, match="Invalid collection name"):
        oceanbase_module.OceanBaseVector("invalid-name", _config(oceanbase_module))


def test_distance_to_score_for_supported_metrics(oceanbase_module):
    # __new__ bypasses __init__ (and its client connection); only _config is needed here.
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="l2")
    assert vector._distance_to_score(3.0) == pytest.approx(0.25)

    vector._config = SimpleNamespace(metric_type="cosine")
    assert vector._distance_to_score(0.2) == pytest.approx(0.8)

    vector._config = SimpleNamespace(metric_type="inner_product")
    assert vector._distance_to_score(-0.2) == pytest.approx(0.2)


def test_get_distance_func_raises_for_unknown_metric(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="manhattan")

    with pytest.raises(ValueError, match="Unsupported metric_type"):
        vector._get_distance_func()


def test_process_search_results_handles_json_and_score_threshold(oceanbase_module):
    """Rows with JSON-string or non-JSON metadata are tolerated; rows below the threshold drop out."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    rows = [
        ("doc-1", '{"doc_id":"1"}', 0.9),
        ("doc-2", "not-json", 0.8),
        ("doc-3", {"doc_id": "3"}, 0.3),
    ]

    docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank")

    assert len(docs) == 2
    assert docs[0].metadata["doc_id"] == "1"
    assert docs[0].metadata["rank"] == 0.9
    assert docs[1].metadata["rank"] == 0.8


def test_search_by_vector_validates_document_id_format(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()

    # "bad id" contains a space, which the ID-format validation rejects.
    with pytest.raises(ValueError, match="Invalid document ID format"):
        vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"])


def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._hybrid_search_enabled = False
    vector._collection_name = "collection_1"

    assert vector.search_by_full_text("query") == []


def test_check_hybrid_search_support_uses_version_comment(oceanbase_module):
    """Hybrid search support is decided from the server version string (>= 4.3.5.1 passes)."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client = MagicMock()
    cursor = MagicMock()
    cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",)
    vector._client.perform_raw_text_sql.return_value = cursor

    assert vector._check_hybrid_search_support() is True

    cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",)
    assert vector._check_hybrid_search_support() is False


def test_init_get_type_and_field_loading(oceanbase_module):
    config = _config(oceanbase_module)
    config.enable_hybrid_search = False

    table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")])
    fake_client = oceanbase_module.ObVecClient()
    fake_client.check_table_exists.return_value = True
    fake_client.metadata_obj.tables = {"collection_1": table}

    with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client):
        vector = oceanbase_module.OceanBaseVector("collection_1", config)

    assert vector.get_type() == "oceanbase"
    assert vector.field_exists("text") is True


def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._fields = []
    vector._client = MagicMock()
    vector._client.metadata_obj.tables = {}

    # Missing table: fields stay empty, no error raised.
    vector._load_collection_fields()
    assert vector._fields == []

    # A failing columns accessor is also swallowed.
    vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))}
    vector._load_collection_fields()
    assert vector._fields == []


def test_create_delegates_to_collection_and_insert(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="text", metadata={"doc_id": "1"})]

    vector.create(docs, [[0.1, 0.2]])

    # Dimension is inferred from the first embedding.
    assert vector._vec_dim == 2
    vector._create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])


def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = False
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    # Redis cache hit: no table existence check at all.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection()
    vector._client.check_table_exists.assert_not_called()

    # Cache miss but table already exists: creation (and delete) is skipped.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.check_table_exists.return_value = True
    vector._create_collection()
    vector.delete.assert_not_called()


def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 3
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    # First SQL call returns the memory-limit row ("30" = already configured); later calls are no-ops.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        None,
        None,
    ]
    index_params = MagicMock()
    vector._client.prepare_index_params.return_value = index_params
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    vector._create_collection()

    vector.delete.assert_called_once()
    vector._client.create_table_with_index_params.assert_called_once()
    index_params.add_index.assert_called_once()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])
    oceanbase_module.redis_client.set.assert_called_once()


def test_create_collection_error_paths(oceanbase_module, monkeypatch):
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    # Empty result set for the memory-limit variable.
    vector._client.perform_raw_text_sql.return_value = []
    with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"):
        vector._create_collection()

    # Limit is "0" and the follow-up SET statement fails (e.g. missing privilege).
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "0"]],
        RuntimeError("no privilege"),
    ]
    with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"):
        vector._create_collection()

    # Unknown full-text parser name is rejected before any DDL.
    vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]]
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid")
    with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"):
        vector._create_collection()


def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    # Full-text index DDL failure is fatal when hybrid search is enabled.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        RuntimeError("fulltext failed"),
    ]
    with pytest.raises(Exception, match="Failed to add fulltext index"):
        vector._create_collection()

    # Metadata-index SQLAlchemyError is tolerated; creation still completes.
    vector._hybrid_search_enabled = False
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        SQLAlchemyError("metadata index failed"),
    ]
    vector._create_collection()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])


def test_check_hybrid_search_support_false_and_exception(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(enable_hybrid_search=False)
    vector._client = MagicMock()
    assert vector._check_hybrid_search_support() is False

    # Version probe failure degrades to "not supported" instead of raising.
    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom")
    assert vector._check_hybrid_search_support() is False


def test_add_texts_batches_refresh_and_exceptions(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2)
    vector._client = MagicMock()
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
        Document(page_content="c", metadata={"doc_id": "id-3"}),
    ]

    # 3 docs with batch_size=2 -> two insert calls; threshold reached -> one index refresh.
    vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    assert vector._client.insert.call_count == 2
    vector._client.refresh_index.assert_called_once()

    vector._client.insert.reset_mock()
    vector._client.refresh_index.reset_mock()
    vector._client.insert.side_effect = RuntimeError("insert failed")
    with pytest.raises(Exception, match="Failed to insert batch"):
        vector.add_texts([docs[0]], [[0.1]])

    # refresh_index SQLAlchemyError is tolerated (insert already succeeded).
    vector._client.insert.side_effect = None
    vector._client.insert.return_value = None
    vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed")
    vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1)
    vector._get_uuids.return_value = ["id-1"]
    vector.add_texts([docs[0]], [[0.1]])


def test_text_exists_and_delete_by_ids(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.get.return_value = SimpleNamespace(rowcount=1)
    assert vector.text_exists("id-1") is True

    vector._client.get.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to check text existence"):
        vector.text_exists("id-1")

    # Empty id list: no delete issued.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()

    vector._client.delete.side_effect = None
    vector.delete_by_ids(["id-1"])
    vector._client.delete.assert_called_once()

    vector._client.delete.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete documents"):
        vector.delete_by_ids(["id-1"])


def test_get_ids_and_delete_by_metadata_field(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    execute_result = [("id-1",), ("id-2",)]

    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value = execute_result
    vector._client.engine.connect.return_value = conn

    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert ids == ["id-1", "id-2"]

    # A key containing invalid characters is rejected with a wrapped error.
    with pytest.raises(Exception, match="Failed to query documents by metadata field"):
        vector.get_ids_by_metadata_field("bad key!", "doc-1")

    # delete_by_metadata_field delegates to delete_by_ids, skipping empty lookups.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])

    vector.get_ids_by_metadata_field = MagicMock(return_value=[])
    vector.delete_by_ids.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_not_called()


def test_search_by_full_text_paths(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hybrid_search_enabled = True
    vector.field_exists = MagicMock(return_value=False)

    # No full-text column -> empty result, no SQL issued.
    assert vector.search_by_full_text("query") == []

    vector.field_exists.return_value = True
    vector._client = MagicMock()
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = tx
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)]
    vector._client.engine.connect.return_value = conn

    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.9

    # top_k=0 is invalid and surfaces as a wrapped search failure.
    with pytest.raises(Exception, match="Full-text search failed"):
        vector.search_by_full_text("query", top_k=0)


def test_search_by_vector_paths(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()
    vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)]
    vector._process_search_results = MagicMock(return_value=["doc"])

    docs = vector.search_by_vector(
        [0.1, 0.2],
        ef_search=10,
        top_k=3,
        score_threshold=0.1,
        document_ids_filter=["good_id"],
    )
    assert docs == ["doc"]
    # A non-default ef_search propagates to the client once.
    vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10)

    with pytest.raises(ValueError, match="Invalid score_threshold parameter"):
        vector.search_by_vector([0.1], score_threshold="x")

    vector._client.ann_search.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Vector search failed"):
        vector.search_by_vector([0.1], score_threshold=0.1)


def test_get_distance_func_and_distance_to_score_errors(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="cosine")
    assert vector._get_distance_func() is oceanbase_module.cosine_distance

    vector._config = SimpleNamespace(metric_type="unknown")
    with pytest.raises(ValueError, match="Unsupported metric_type"):
        vector._distance_to_score(0.1)


def test_delete_success_and_exception(oceanbase_module):
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()

    vector.delete()
    vector._client.drop_table_if_exist.assert_called_once_with("collection_1")

    vector._client.drop_table_if_exist.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete collection"):
        vector.delete()


def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
    """Factory reuses the collection name from index_struct_dict, else generates one (lowercased)."""
    factory = oceanbase_module.OceanBaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000)

    with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].args[0] == "existing_collection"
    assert vector_cls.call_args_list[1].args[0] == "auto_collection"
    assert dataset_without_index.index_struct is not None
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py
new file mode 100644
index 0000000000..6641dbe4a0
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py
@@ -0,0 +1,400 @@
import importlib
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError

from core.rag.models.document import Document


def _build_fake_psycopg2_modules():
    """Build stub psycopg2 / psycopg2.pool / psycopg2.extras modules for import-time patching."""
    psycopg2 = types.ModuleType("psycopg2")
    psycopg2.__path__ = []  # mark as a package so submodule imports resolve
    psycopg2_extras = types.ModuleType("psycopg2.extras")
    psycopg2_pool = types.ModuleType("psycopg2.pool")

    class SimpleConnectionPool:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs
            self.getconn = MagicMock()
            self.putconn = MagicMock()

    psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool
    psycopg2_extras.execute_values = MagicMock()

    psycopg2.pool = psycopg2_pool
    psycopg2.extras = psycopg2_extras
    return {
        "psycopg2": psycopg2,
        "psycopg2.pool": psycopg2_pool,
        "psycopg2.extras": psycopg2_extras,
    }
@pytest.fixture
def opengauss_module(monkeypatch):
    # Install all psycopg2 stubs before importing so module-level imports bind to them.
    for name, module in _build_fake_psycopg2_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.opengauss.opengauss as module

    return importlib.reload(module)


def _config(module, *, enable_pq=False):
    """Return a minimal valid OpenGaussConfig; PQ indexing is opt-in per test."""
    return module.OpenGaussConfig(
        host="localhost",
        port=6600,
        user="postgres",
        password="password",
        database="dify",
        min_connection=1,
        max_connection=5,
        enable_pq=enable_pq,
    )


@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OPENGAUSS_HOST is required"),
        ("port", 0, "config OPENGAUSS_PORT is required"),
        ("user", "", "config OPENGAUSS_USER is required"),
        ("password", "", "config OPENGAUSS_PASSWORD is required"),
        ("database", "", "config OPENGAUSS_DATABASE is required"),
        ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"),
        ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"),
    ],
)
def test_opengauss_config_validation(opengauss_module, field, value, message):
    values = _config(opengauss_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        opengauss_module.OpenGaussConfig.model_validate(values)


def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module):
    values = _config(opengauss_module).model_dump()
    values["min_connection"] = 6
    values["max_connection"] = 5

    with pytest.raises(ValidationError, match="OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"):
        opengauss_module.OpenGaussConfig.model_validate(values)


def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))

    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))

    # Table name is derived from the collection name with an "embedding_" prefix.
    assert vector.table_name == "embedding_collection_1"
    assert vector.get_type() == "opengauss"
    assert vector.pool is pool


def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True))
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    vector._create_index(1536)

    # PQ mode emits extra session settings alongside the index DDL.
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("enable_pq=on" in sql for sql in executed_sql)
    assert any("SET hnsw_earlystop_threshold = 320" in sql for sql in executed_sql)
    opengauss_module.redis_client.set.assert_called_once()


def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # 3072 dims: no index SQL is executed, but the cache flag is still set.
    vector._create_index(3072)

    cursor.execute.assert_not_called()
    opengauss_module.redis_client.set.assert_called_once()


def test_search_by_vector_validates_top_k(opengauss_module):
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1, 0.2], top_k=0)


def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    vector._get_cursor = MagicMock()

    vector.delete_by_ids([])

    vector._get_cursor.assert_not_called()


def test_get_cursor_closes_commits_and_returns_connection(opengauss_module):
    """_get_cursor yields a cursor, then closes it, commits, and returns the conn to the pool."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    pool = MagicMock()
    conn = MagicMock()
    cur = MagicMock()
    pool.getconn.return_value = conn
    conn.cursor.return_value = cur
    vector.pool = pool

    with vector._get_cursor() as got_cur:
        assert got_cur is cur

    cur.close.assert_called_once()
    conn.commit.assert_called_once()
    pool.putconn.assert_called_once_with(conn)


def test_create_calls_collection_insert_and_index(opengauss_module):
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    vector._create_index = MagicMock()
    docs = [Document(page_content="text", metadata={"doc_id": "seg-1"})]

    vector.create(docs, [[0.1, 0.2]])

    # Dimension (2) flows from the first embedding into both DDL steps.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
    vector._create_index.assert_called_once_with(2)


def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    vector._get_cursor = MagicMock()

    vector._create_index(1536)

    # Redis flag already present: no SQL, no re-set of the flag.
    vector._get_cursor.assert_not_called()
    opengauss_module.redis_client.set.assert_not_called()


def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    vector._create_index(1536)

    sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql)


def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cursor = MagicMock()
    opengauss_module.psycopg2.extras.execute_values.reset_mock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # Second doc has metadata=None, exercising the uuid fallback path.
    docs = [
        Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}),
        SimpleNamespace(page_content="text-2", metadata=None),
    ]
    monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid")

    ids = vector.add_texts(docs, [[0.1], [0.2]])

    # Only docs with metadata contribute to the returned id list.
    assert ids == ["seg-1"]
    opengauss_module.psycopg2.extras.execute_values.assert_called_once()


def test_text_exists_and_get_by_ids(opengauss_module):
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("seg-1",)
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    assert vector.text_exists("seg-1") is True
    docs = vector.get_by_ids(["seg-1", "seg-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"


def test_delete_and_metadata_field_queries(opengauss_module):
    """delete_by_ids / delete_by_metadata_field / delete issue the expected SQL shapes."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    vector.delete_by_ids(["seg-1", "seg-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()

    sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in query for query in sql)
    assert any("meta->>%s = %s" in query for query in sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in query for query in sql)


def test_search_by_vector_and_full_text(opengauss_module):
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows are (meta, text, distance); score = 1 - distance.
    cursor.__iter__.return_value = iter(
        [
            ({"doc_id": "1"}, "text-1", 0.1),
            ({"doc_id": "2"}, "text-2", 0.6),
        ]
    )

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)

    cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)])
    full_docs = vector.search_by_full_text("hello world", top_k=2)
    assert len(full_docs) == 1
    assert full_docs[0].page_content == "full-text"


def test_search_by_full_text_validates_top_k(opengauss_module):
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("query", top_k=0)


def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    # Cache hit: table creation skipped entirely.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(1536)
    cursor.execute.assert_not_called()

    # Cache miss: CREATE TABLE executed once and the flag recorded.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(1536)
    cursor.execute.assert_called_once()
    opengauss_module.redis_client.set.assert_called_once()


def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
    """Factory keeps an existing class_prefix verbatim, else generates a name for the dataset."""
    factory = opengauss_module.OpenGaussFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_HOST", "localhost")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PORT", 6600)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_USER", "postgres")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PASSWORD", "password")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_DATABASE", "dify")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MIN_CONNECTION", 1)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MAX_CONNECTION", 5)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_ENABLE_PQ", False)

    with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py
new file mode 100644
index 0000000000..1030158dd1
--- /dev/null
+++ 
b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py @@ -0,0 +1,360 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class Urllib3AWSV4SignerAuth: + def __init__(self, credentials, region, service): + self.credentials = credentials + self.region = region + self.service = service + + class Urllib3HttpConnection: + pass + + class _IndicesClient: + def __init__(self): + self.exists = MagicMock(return_value=False) + self.create = MagicMock() + self.delete = MagicMock() + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = _IndicesClient() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.get = MagicMock() + + helpers = SimpleNamespace(bulk=MagicMock()) + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth + opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection + opensearchpy.helpers = helpers + opensearchpy_helpers.BulkIndexError = BulkIndexError + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearchpy_helpers, + } + + +@pytest.fixture +def opensearch_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opensearch.opensearch_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": 
"basic", + "user": "admin", + "password": "secret", + } + values.update(overrides) + return module.OpenSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENSEARCH_HOST is required"), + ("port", 0, "config OPENSEARCH_PORT is required"), + ], +) +def test_config_validation_required_fields(opensearch_module, field, value, message): + values = _config(opensearch_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_config_validation_for_aws_auth_and_https_fields(opensearch_module): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "aws_managed_iam", + "user": "admin", + "password": "secret", + } + with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"): + opensearch_module.OpenSearchConfig.model_validate(values) + + values = _config(opensearch_module).model_dump() + values["OPENSEARCH_SECURE"] = False + values["OPENSEARCH_VERIFY_CERTS"] = True + with pytest.raises(ValidationError, match="verify_certs=True requires secure"): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch): + class _Session: + def get_credentials(self): + return "creds" + + boto3 = types.ModuleType("boto3") + boto3.Session = _Session + monkeypatch.setitem(sys.modules, "boto3", boto3) + + config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-east-1", + aws_service="es", + ) + auth = config.create_aws_managed_iam_auth() + + assert auth.credentials == "creds" + assert auth.region == "us-east-1" + assert auth.service == "es" + + +def test_to_opensearch_params_supports_basic_and_aws(opensearch_module): + basic_params = _config(opensearch_module).to_opensearch_params() + assert basic_params["http_auth"] == ("admin", "secret") 
+ + aws_config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-west-2", + aws_service="es", + ) + with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"): + aws_params = aws_config.to_opensearch_params() + + assert aws_params["http_auth"] == "iam-auth" + + +def test_init_and_create_delegate_calls(opensearch_module): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "opensearch" + vector.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es")) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id")) + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.1], [0.2]]) + actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert len(actions) == 2 + assert all("_id" in action for action in actions) + + vector._client_config.aws_service = "aoss" + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.3], [0.4]]) + aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert all("_id" not in action for action in aoss_actions) + + +def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + 
vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector._client.search.return_value = {"hits": {"hits": []}} + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + +def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + opensearch_module.helpers.bulk.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["doc-1"]) + opensearch_module.helpers.bulk.assert_not_called() + + vector._client.indices.exists.return_value = True + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None]) + vector.delete_by_ids(["doc-1", "doc-2"]) + opensearch_module.helpers.bulk.assert_called_once() + + opensearch_module.helpers.bulk.reset_mock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["es-404"]) + opensearch_module.helpers.bulk.side_effect = opensearch_module.BulkIndexError( + [{"delete": {"status": 404, "_id": "es-404"}}] + ) + vector.delete_by_ids(["doc-404"]) + assert opensearch_module.helpers.bulk.call_count == 1 + + opensearch_module.helpers.bulk.side_effect = None + + +def test_delete_and_text_exists(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True) + + vector._client.get.return_value = {"_id": "id-1"} + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("not found") + assert vector.text_exists("id-1") is False 
+ + +def test_search_by_vector_validates_and_builds_documents(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + + with pytest.raises(ValueError, match="query_vector should be a list"): + vector.search_by_vector("not-a-list") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, 1]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-1", + opensearch_module.Field.METADATA_KEY: None, + }, + "_score": 0.9, + }, + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-2", + opensearch_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + "_score": 0.1, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"]) + query = vector._client.search.call_args.kwargs["body"] + assert "script_score" in query["query"] + + +def test_search_by_vector_reraises_client_error(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.search_by_vector([0.1, 0.2]) + + +def test_search_by_full_text_and_filters(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.METADATA_KEY: {"doc_id": "1"}, + opensearch_module.Field.VECTOR: [0.1], + opensearch_module.Field.CONTENT_KEY: "matched text", + } + }, + ] + } + } + + docs = vector.search_by_full_text("hello", document_ids_filter=["d-1"]) + + assert len(docs) == 1 + assert docs[0].page_content 
== "matched text" + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}] + + +def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock()) + + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1)) + vector._client.indices.create.reset_mock() + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_called_once() + index_body = vector._client.indices.create.call_args.kwargs["body"] + assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2 + opensearch_module.redis_client.set.assert_called() + + +def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch): + factory = opensearch_module.OpenSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_HOST", "localhost") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PORT", 9200) + 
monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_SECURE", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_VERIFY_CERTS", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AUTH_METHOD", "basic") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_USER", "admin") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PASSWORD", "secret") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_REGION", None) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_SERVICE", None) + + with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py new file mode 100644 index 0000000000..817a7d342b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py @@ -0,0 +1,375 @@ +import array +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_oracle_modules(): + jieba = types.ModuleType("jieba") + jieba_posseg = types.ModuleType("jieba.posseg") + jieba_posseg.cut = MagicMock(return_value=[]) + jieba.posseg = jieba_posseg + + oracledb = types.ModuleType("oracledb") + 
oracledb_connection = types.ModuleType("oracledb.connection") + + class Connection: + pass + + oracledb_connection.Connection = Connection + oracledb.defaults = SimpleNamespace(fetch_lobs=True) + oracledb.DB_TYPE_VECTOR = object() + oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock())) + oracledb.connect = MagicMock() + + return { + "jieba": jieba, + "jieba.posseg": jieba_posseg, + "oracledb": oracledb, + "oracledb.connection": oracledb_connection, + } + + +def _connection_with_cursor(cursor): + cursor_ctx = MagicMock() + cursor_ctx.__enter__.return_value = cursor + cursor_ctx.__exit__.return_value = None + + connection = MagicMock() + connection.__enter__.return_value = connection + connection.__exit__.return_value = None + connection.cursor.return_value = cursor_ctx + return connection + + +@pytest.fixture +def oracle_module(monkeypatch): + for name, module in _build_fake_oracle_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.oracle.oraclevector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "user": "system", + "password": "oracle", + "dsn": "oracle:1521/freepdb1", + "is_autonomous": False, + } + values.update(overrides) + return module.OracleVectorConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("user", "", "config ORACLE_USER is required"), + ("password", "", "config ORACLE_PASSWORD is required"), + ("dsn", "", "config ORACLE_DSN is required"), + ], +) +def test_oracle_config_validation_required_fields(oracle_module, field, value, message): + values = _config(oracle_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oracle_module.OracleVectorConfig.model_validate(values) + + +def test_oracle_config_validation_autonomous_requirements(oracle_module): + with pytest.raises(ValidationError, match="config_dir is required"): + 
oracle_module.OracleVectorConfig.model_validate( + {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True} + ) + + +def test_init_and_get_type(oracle_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool)) + vector = oracle_module.OracleVector("collection_1", _config(oracle_module)) + + assert vector.get_type() == "oracle" + assert vector.table_name == "embedding_collection_1" + assert vector.pool is pool + + +def test_numpy_converters_and_type_handlers(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + + in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64)) + in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32)) + in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8)) + assert in_float64.typecode == "d" + assert in_float32.typecode == "f" + assert in_int8.typecode == "b" + + cursor = MagicMock() + vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2) + cursor.var.assert_called_with( + oracle_module.oracledb.DB_TYPE_VECTOR, + arraysize=2, + inconverter=vector.numpy_converter_in, + ) + + metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR) + cursor.arraysize = 3 + vector.output_type_handler(cursor, metadata) + cursor.var.assert_called_with( + metadata.type_code, + arraysize=3, + outconverter=vector.numpy_converter_out, + ) + + out_int8 = vector.numpy_converter_out(array.array("b", [1])) + assert out_int8.dtype == numpy.int8 + out_float32 = vector.numpy_converter_out(array.array("f", [1.0])) + assert out_float32.dtype == numpy.float32 + out_float64 = vector.numpy_converter_out(array.array("d", [1.0])) + assert out_float64.dtype == numpy.float64 + + +def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch): + connect = MagicMock(return_value="connection") + 
monkeypatch.setattr(oracle_module.oracledb, "connect", connect) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.config = _config(oracle_module) + assert vector._get_connection() == "connection" + connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1") + + vector.config = _config( + oracle_module, + is_autonomous=True, + config_dir="/wallet", + wallet_location="/wallet", + wallet_password="pw", + ) + vector._get_connection() + assert connect.call_args.kwargs["config_dir"] == "/wallet" + assert connect.call_args.kwargs["wallet_location"] == "/wallet" + + +def test_create_delegates_collection_and_insert(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="doc", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.execute.side_effect = [None, RuntimeError("insert failed")] + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + 
assert cursor.execute.call_count == 2 + assert connection.commit.call_count >= 1 + connection.close.assert_called() + + +def test_text_exists_and_get_by_ids(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.pool = MagicMock() + + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + vector.pool.release.assert_called_once() + assert vector.get_by_ids([]) == [] + + +def test_delete_methods(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + vector.delete_by_ids([]) + vector._get_connection.assert_not_called() + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql) + assert any("JSON_VALUE(meta" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_with_threshold_and_filter(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, 
"doc-2", 0.8)]) + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=0, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "fetch first 4 rows only" in sql + assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql + + +def _fake_nltk_module(*, missing_data=False): + nltk = types.ModuleType("nltk") + nltk_corpus = types.ModuleType("nltk.corpus") + + class _Data: + @staticmethod + def find(_path): + if missing_data: + raise LookupError("missing") + return True + + nltk.data = _Data() + nltk.word_tokenize = lambda text: text.split() + nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"]) + return nltk, nltk_corpus + + +def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("张", "nr"), ("三", "nr"), ("。", "x")])) + zh_docs = vector.search_by_full_text("张三", top_k=2) + assert len(zh_docs) == 1 + zh_params = cursor.execute.call_args.args[1] + assert zh_params["kk"] == "张三" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=False) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])]) + en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"]) + assert len(en_docs) == 1 + en_sql = 
cursor.execute.call_args.args[0] + en_params = cursor.execute.call_args.args[1] + assert "fetch first 5 rows only" in en_sql + assert "doc_id_0" in en_params + + +def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector._get_connection = MagicMock() + + empty_result = vector.search_by_full_text("") + assert empty_result[0].page_content == "" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=True) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + with pytest.raises(LookupError, match="required NLTK data package"): + vector.search_by_full_text("english query") + + +def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock()) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_not_called() + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(2) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql) + 
oracle_module.redis_client.set.assert_called_once() + + +def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch): + factory = oracle_module.OracleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False) + + with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py new file mode 100644 index 0000000000..1aec81b8ac --- /dev/null +++ 
b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -0,0 +1,317 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_pgvecto_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSessionContext: + def __init__(self, calls, execute_results=None): + self.calls = calls + self.execute_results = execute_results or [] + self.execute = MagicMock(side_effect=self._execute_side_effect) + self.commit = MagicMock() + + def _execute_side_effect(self, *args, **kwargs): + self.calls.append((args, kwargs)) + if self.execute_results: + return self.execute_results.pop(0) + return MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(calls, execute_results=None): + def _session(_client): + return _FakeSessionContext(calls=calls, execute_results=execute_results) + + return _session + + +@pytest.fixture +def pgvecto_module(monkeypatch): + for name, module in _build_fake_pgvecto_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module + import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module + + return importlib.reload(module), importlib.reload(collection_module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "postgres", + } + 
values.update(overrides) + return module.PgvectoRSConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config PGVECTO_RS_HOST is required"), + ("port", 0, "config PGVECTO_RS_PORT is required"), + ("user", "", "config PGVECTO_RS_USER is required"), + ("password", "", "config PGVECTO_RS_PASSWORD is required"), + ("database", "", "config PGVECTO_RS_DATABASE is required"), + ], +) +def test_pgvecto_config_validation(pgvecto_module, field, value, message): + module, _ = pgvecto_module + values = _config(module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + module.PgvectoRSConfig.model_validate(values) + + +def test_collection_base_has_expected_annotations(pgvecto_module): + _, collection_module = pgvecto_module + annotations = collection_module.CollectionORM.__annotations__ + assert {"id", "text", "meta", "vector"} <= set(annotations) + + +def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == module.VectorType.PGVECTO_RS + module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres") + assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls) + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + 
session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(module.redis_client, "set", MagicMock()) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection(3) + assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None)) + vector.create_collection(3) + assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls) + module.redis_client.set.assert_called() + + +def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + runtime_calls = [] + execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] + + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + + class _InsertBuilder: + def __init__(self, table): + self.table = table + + def values(self, **kwargs): + return ("insert", kwargs) + + monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table)) + monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"])) + docs = [ + 
Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["uuid-1", "uuid-2"] + assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0]) + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("row-id-1",)]), + MagicMock(), + ], + ), + ) + vector.delete_by_ids(["doc-1"]) + assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + vector.delete() + assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) + + +def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", 
_config(module), dim=3) + + runtime_calls = [] + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("id-1",)]), + SimpleNamespace(fetchall=lambda: []), + ], + ), + ) + assert vector.text_exists("doc-1") is True + assert vector.text_exists("doc-1") is False + + class _DistanceExpr: + def label(self, _name): + return self + + class _VectorColumn: + def op(self, _operator, return_type=None): + def _call(_query_vector): + return _DistanceExpr() + + return _call + + class _MetaFilter: + def in_(self, values): + return ("in", values) + + class _MetaColumn: + def __getitem__(self, _item): + return _MetaFilter() + + class _Stmt: + def __init__(self): + self.where_called = False + + def limit(self, _value): + return self + + def order_by(self, _value): + return self + + def where(self, _value): + self.where_called = True + return self + + stmt = _Stmt() + monkeypatch.setattr(module, "select", lambda *_args: stmt) + + vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn()) + rows = [ + (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), + (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), + ] + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert stmt.where_called is True + assert vector.search_by_full_text("hello") == [] + + +def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + factory = module.PGVectoRSFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, 
index_struct=None) + + monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432) + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres") + + embeddings = MagicMock() + embeddings.embed_query.return_value = [0.1, 0.2, 0.3] + + with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py index 4998a9858f..7505262eb7 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py @@ -1,16 +1,19 @@ -import unittest +from contextlib import contextmanager +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module from core.rag.datasource.vdb.pgvector.pgvector import ( PGVector, PGVectorConfig, ) +from core.rag.models.document import Document -class TestPGVector(unittest.TestCase): - def setUp(self): +class TestPGVector: + def setup_method(self, method): self.config = PGVectorConfig( 
host="localhost", port=5432, @@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override): PGVectorConfig(**config) -if __name__ == "__main__": - unittest.main() +def test_create_delegates_collection_creation_and_insert(): + vector = PGVector.__new__(PGVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["doc-a"]) + docs = [Document(page_content="hello", metadata={"doc_id": "doc-a"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["doc-a"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid") + execute_values = MagicMock() + monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values) + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + execute_values.assert_called_once() + + +def test_text_get_and_delete_methods(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == 
"text-1" + + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + class _UndefinedTableError(Exception): + pass + + monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError) + cursor.execute.side_effect = _UndefinedTableError("missing") + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"): + vector.delete_by_ids(["doc-1"]) + + +def test_search_by_vector_supports_filter_and_threshold(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "meta->>'document_id' in ('d-1')" in sql + + +def test_search_by_full_text_branches_for_bigm_and_standard(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = 
MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + vector.pg_bigm = False + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.7) + standard_sql = cursor.execute.call_args.args[0] + assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql + + cursor.execute.reset_mock() + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)]) + vector.pg_bigm = True + vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"]) + assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0] + assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0] + + +def test_pgvector_factory_initializes_expected_collection_name(monkeypatch): + factory = pgvector_module.PGVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1) + 
monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False) + + with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py new file mode 100644 index 0000000000..bd8df520ba --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py @@ -0,0 +1,269 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": 
psycopg2_extras, + } + + +@pytest.fixture +def vastbase_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VastbaseVectorConfig( + host="localhost", + port=5432, + user="dify", + password="secret", + database="dify", + min_connection=1, + max_connection=5, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config VASTBASE_HOST is required"), + ("port", 0, "config VASTBASE_PORT is required"), + ("user", "", "config VASTBASE_USER is required"), + ("password", "", "config VASTBASE_PASSWORD is required"), + ("database", "", "config VASTBASE_DATABASE is required"), + ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"), + ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"), + ], +) +def test_vastbase_config_validation(vastbase_module, field, value, message): + values = _config(vastbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + vastbase_module.VastbaseVectorConfig.model_validate(values) + + +def test_vastbase_config_rejects_invalid_connection_window(vastbase_module): + with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"): + vastbase_module.VastbaseVectorConfig.model_validate( + { + "host": "localhost", + "port": 5432, + "user": "dify", + "password": "secret", + "database": "dify", + "min_connection": 6, + "max_connection": 5, + } + ) + + +def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + + vector = 
vastbase_module.VastbaseVector("collection_1", _config(vastbase_module)) + assert vector.get_type() == "vastbase" + assert vector.table_name == "embedding_collection_1" + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_and_add_texts(vastbase_module, monkeypatch): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + vector._create_collection = MagicMock() + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid") + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert ids == ["doc-a", "generated-uuid"] + vastbase_module.psycopg2.extras.execute_values.assert_called_once() + + vector.add_texts = MagicMock(return_value=["doc-a"]) + result = vector.create(docs, [[0.1], [0.2], [0.3]]) + vector._create_collection.assert_called_once_with(1) + assert result == ["doc-a"] + + +def test_text_get_delete_and_metadata_methods(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_ids([]) + 
vector.delete_by_ids(["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql) + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_and_full_text(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.8), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock()) + + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector._collection_name = "collection_1" + vector.table_name = 
"embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + cursor.execute.assert_not_called() + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(17000) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql) + + cursor.execute.reset_mock() + vector._create_collection(3) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql) + vastbase_module.redis_client.set.assert_called() + + +def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch): + factory = vastbase_module.VastbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1) + monkeypatch.setattr(vastbase_module.dify_config, 
"VASTBASE_MAX_CONNECTION", 5) + + with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py new file mode 100644 index 0000000000..0408506563 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py @@ -0,0 +1,328 @@ +import importlib +import os +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_qdrant_modules(): + qdrant_client = types.ModuleType("qdrant_client") + qdrant_http = types.ModuleType("qdrant_client.http") + qdrant_http_models = types.ModuleType("qdrant_client.http.models") + qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions") + qdrant_local_pkg = types.ModuleType("qdrant_client.local") + qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local") + + class UnexpectedResponseError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + class FilterSelector: + def __init__(self, filter): + self.filter = filter + + class HnswConfigDiff: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class TextIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class 
VectorParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class PointStruct: + def __init__(self, **kwargs): + self.id = kwargs["id"] + self.vector = kwargs["vector"] + self.payload = kwargs["payload"] + + class Filter: + def __init__(self, must=None): + self.must = must or [] + + class FieldCondition: + def __init__(self, key, match): + self.key = key + self.match = match + + class MatchValue: + def __init__(self, value): + self.value = value + + class MatchAny: + def __init__(self, any): + self.any = any + + class MatchText: + def __init__(self, text): + self.text = text + + class _Distance(UserDict): + def __getitem__(self, key): + return key + + class QdrantClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[])) + self.create_collection = MagicMock() + self.create_payload_index = MagicMock() + self.upsert = MagicMock() + self.delete = MagicMock() + self.delete_collection = MagicMock() + self.retrieve = MagicMock(return_value=[]) + self.search = MagicMock(return_value=[]) + self.scroll = MagicMock(return_value=([], None)) + + class QdrantLocal(QdrantClient): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._load = MagicMock() + + qdrant_client.QdrantClient = QdrantClient + qdrant_http_models.FilterSelector = FilterSelector + qdrant_http_models.HnswConfigDiff = HnswConfigDiff + qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD") + qdrant_http_models.TextIndexParams = TextIndexParams + qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT") + qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL") + qdrant_http_models.VectorParams = VectorParams + qdrant_http_models.Distance = _Distance() + qdrant_http_models.PointStruct = PointStruct + qdrant_http_models.Filter = Filter + qdrant_http_models.FieldCondition = FieldCondition + qdrant_http_models.MatchValue = MatchValue + 
qdrant_http_models.MatchAny = MatchAny + qdrant_http_models.MatchText = MatchText + qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError + + qdrant_http.models = qdrant_http_models + qdrant_local_mod.QdrantLocal = QdrantLocal + qdrant_local_pkg.qdrant_local = qdrant_local_mod + + return { + "qdrant_client": qdrant_client, + "qdrant_client.http": qdrant_http, + "qdrant_client.http.models": qdrant_http_models, + "qdrant_client.http.exceptions": qdrant_http_exceptions, + "qdrant_client.local": qdrant_local_pkg, + "qdrant_client.local.qdrant_local": qdrant_local_mod, + } + + +@pytest.fixture +def qdrant_module(monkeypatch): + for name, module in _build_fake_qdrant_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.qdrant.qdrant_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "endpoint": "http://localhost:6333", + "api_key": "api-key", + "timeout": 20, + "root_path": "/tmp", + "grpc_port": 6334, + "prefer_grpc": False, + "replication_factor": 1, + "write_consistency_factor": 1, + } + values.update(overrides) + return module.QdrantConfig.model_validate(values) + + +def test_qdrant_config_to_params(qdrant_module): + url_params = _config(qdrant_module).to_qdrant_params().model_dump() + assert url_params["url"] == "http://localhost:6333" + assert url_params["verify"] is False + + path_config = _config(qdrant_module, endpoint="path:storage") + assert path_config.to_qdrant_params().path == os.path.join("/tmp", "storage") + + with pytest.raises(ValueError, match="Root path is not set"): + _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params() + + +def test_init_and_basic_behaviour(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.get_type() == qdrant_module.VectorType.QDRANT + assert vector.to_index_struct()["vector_store"]["class_prefix"] == 
"collection_1" + + docs = [Document(page_content="a", metadata={"doc_id": "a"})] + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with("collection_1", 1) + vector.add_texts.assert_called_once() + + +def test_create_collection_and_add_texts(qdrant_module, monkeypatch): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.get_collections.return_value = SimpleNamespace(collections=[]) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_called_once() + assert vector._client.create_payload_index.call_count == 4 + qdrant_module.redis_client.set.assert_called_once() + + # add_texts and generated batches + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.upsert.call_count == 1 + + payloads = qdrant_module.QdrantVector._build_payloads( + ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + assert payloads[0]["group_id"] == "g1" + with pytest.raises(ValueError, match="At least one of the texts is None"): + qdrant_module.QdrantVector._build_payloads( + [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + + +def 
test_delete_and_exists_paths(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete() + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete() + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = None + + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")]) + assert vector.text_exists("id-1") is False + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]) + vector._client.retrieve.return_value = [{"id": "id-1"}] + assert vector.text_exists("id-1") is True + + +def test_search_and_helper_methods(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.search_by_vector([0.1], score_threshold=1.0) == [] + + vector._client.search.return_value = [ + SimpleNamespace(payload=None, score=0.9, vector=[0.1]), + SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]), + ] + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == 
pytest.approx(0.8) + + # full text search: keyword split, dedup and top_k limit + scroll_results = [ + ( + [ + SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]), + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ( + [ + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ] + vector._client.scroll.side_effect = scroll_results + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert vector.search_by_full_text(" ", top_k=2) == [] + + local_client = qdrant_module.QdrantLocal() + vector._client = local_client + vector._reload_if_needed() + local_client._load.assert_called_once() + + doc = vector._document_from_scored_point( + SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]), + "page_content", + "metadata", + ) + assert doc.page_content == "doc" + + +def test_qdrant_factory_paths(qdrant_module, monkeypatch): + factory = qdrant_module.QdrantVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + collection_binding_id=None, + index_struct_dict=None, + index_struct=None, + ) + monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root"))) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False) + monkeypatch.setattr(qdrant_module.dify_config, 
"QDRANT_REPLICATION_FACTOR", 1) + + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset.index_struct is not None + + # collection binding lookup path + dataset.collection_binding_id = "binding-1" + dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}} + monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + qdrant_module.db.session.scalars = MagicMock( + return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION")) + ) + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION" + + qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None)) + with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py new file mode 100644 index 0000000000..ca8cd5e514 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py @@ -0,0 +1,303 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_relyt_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = 
types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSession: + def __init__(self, execute_result=None): + self.execute_result = execute_result or MagicMock(fetchall=lambda: []) + self.execute = MagicMock(return_value=self.execute_result) + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +@pytest.fixture +def relyt_module(monkeypatch): + for name, module in _build_fake_relyt_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.relyt.relyt_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "relyt", + } + values.update(overrides) + return module.RelytConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config RELYT_HOST is required"), + ("port", 0, "config RELYT_PORT is required"), + ("user", "", "config RELYT_USER is required"), + ("password", "", "config RELYT_PASSWORD is required"), + ("database", "", "config RELYT_DATABASE is required"), + ], +) +def test_relyt_config_validation(relyt_module, field, value, message): + values = _config(relyt_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + relyt_module.RelytConfig.model_validate(values) + + +def test_init_get_type_and_create_delegate(relyt_module, monkeypatch): + engine = MagicMock() + monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine)) + vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1") + vector.create_collection = MagicMock() + 
vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == relyt_module.VectorType.RELYT + assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt" + assert vector.embedding_dimension == 2 + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock()) + + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + session.execute.assert_not_called() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] + assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) + assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql) + assert any("CREATE INDEX" in sql for sql in executed_sql) + relyt_module.redis_client.set.assert_called_once() + + +def test_add_texts_and_metadata_queries(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector._group_id = "group-1" + vector.client = MagicMock() + + begin_ctx = 
MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + + monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "d-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert conn.execute.call_count >= 1 + first_insert_values = conn.execute.call_args.args[0].compile().params + assert "group_id" in str(first_insert_values) + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"] + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +# 1. 
delete_by_uuids: success and connect error +def test_delete_by_uuids_success_and_connect_error(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + with pytest.raises(ValueError, match="No ids provided"): + vector.delete_by_uuids(None) + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + assert vector.delete_by_uuids(["id-1"]) is True + vector.client.connect.side_effect = RuntimeError("boom") + assert vector.delete_by_uuids(["id-1"]) is False + + +# 2. delete_by_metadata_field calls delete_by_uuids +def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_uuids.assert_called_once_with(["id-1"]) + + +# 3. 
delete_by_ids translates to uuids +def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("uuid-1",), ("uuid-2",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_ids(["doc-1", "doc-2"]) + vector.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"]) + + +# 4. text_exists True +def test_text_exists_true(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-1",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is True + + +# 5. text_exists False +def test_text_exists_false(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is False + + +# 6. 
similarity_search_with_score_by_vector returns Documents and scores +def test_similarity_search_with_score_by_vector(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + result_rows = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8), + ] + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = result_rows + vector.client.connect.return_value = conn + similarities = vector.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]}) + assert len(similarities) == 2 + assert similarities[0][0].page_content == "doc-a" + + +# 7. search_by_vector filters by score and ids +def test_search_by_vector_filters_by_score_and_ids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.similarity_search_with_score_by_vector = MagicMock( + return_value=[ + (Document(page_content="a", metadata={"doc_id": "1"}), 0.1), + (Document(page_content="b", metadata={}), 0.9), + ] + ) + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert vector.search_by_full_text("query") == [] + + +# 8. 
delete commits session +def test_delete_commits_session(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete() + session.commit.assert_called_once() + + +def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): + factory = relyt_module.RelytVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_HOST", "localhost") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PORT", 5432) + monkeypatch.setattr(relyt_module.dify_config, "RELYT_USER", "postgres") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PASSWORD", "secret") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_DATABASE", "relyt") + + with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py 
b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py new file mode 100644 index 0000000000..e3b6676d9b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py @@ -0,0 +1,316 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_tablestore_module(): + tablestore = types.ModuleType("tablestore") + + class _BatchGetRowRequest: + def __init__(self): + self.items = [] + + def add(self, item): + self.items.append(item) + + class _TableInBatchGetRowItem: + def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver): + self.table_name = table_name + self.rows_to_get = rows_to_get + self.columns_to_get = columns_to_get + + class _Row: + def __init__(self, primary_key, attribute_columns=None): + self.primary_key = primary_key + self.attribute_columns = attribute_columns or [] + + class _Client: + def __init__(self, *_args): + self.list_table = MagicMock(return_value=[]) + self.create_table = MagicMock() + self.list_search_index = MagicMock(return_value=[]) + self.create_search_index = MagicMock() + self.delete_search_index = MagicMock() + self.delete_table = MagicMock() + self.put_row = MagicMock() + self.delete_row = MagicMock() + self.get_row = MagicMock(return_value=(None, None, None)) + self.batch_get_row = MagicMock() + self.search = MagicMock() + + tablestore.OTSClient = _Client + tablestore.BatchGetRowRequest = _BatchGetRowRequest + tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem + tablestore.Row = _Row + tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema) + tablestore.TableOptions = lambda: ("table_options",) + tablestore.CapacityUnit = lambda read, write: ("capacity", read, write) + tablestore.ReservedThroughput = lambda cap: 
("reserved", cap) + tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs) + tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs) + tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas) + tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs) + tablestore.TermQuery = lambda key, value: ("term_query", key, value) + tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs) + tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.TermsQuery = lambda key, values: ("terms_query", key, values) + tablestore.Sort = lambda **kwargs: ("sort", kwargs) + tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs) + tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs) + + tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD") + tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD") + tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32") + tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE") + tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX") + tablestore.SortOrder = SimpleNamespace(DESC="DESC") + return tablestore + + +@pytest.fixture +def tablestore_module(monkeypatch): + fake_module = _build_fake_tablestore_module() + monkeypatch.setitem(sys.modules, "tablestore", fake_module) + + import core.rag.datasource.vdb.tablestore.tablestore_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "access_key_id": "ak", + "access_key_secret": "sk", + "instance_name": "instance", + "endpoint": "endpoint", + "normalize_full_text_bm25_score": False, + } + values.update(overrides) + return module.TableStoreConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", 
"message"), + [ + ("access_key_id", "", "config ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"), + ("instance_name", "", "config INSTANCE_NAME is required"), + ("endpoint", "", "config ENDPOINT is required"), + ], +) +def test_tablestore_config_validation(tablestore_module, field, value, message): + values = _config(tablestore_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + tablestore_module.TableStoreConfig.model_validate(values) + + +def test_init_and_basic_delegation(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + assert vector.get_type() == tablestore_module.VectorType.TABLESTORE + assert vector._table_name == "collection_1" + assert vector._index_name == "collection_1_idx" + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "d-1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(documents=docs, embeddings=[[0.1, 0.2]]) + + vector.create_collection([[0.1, 0.2]]) + assert vector._create_collection.call_count == 2 + + +def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + # get_by_ids + ok_item = SimpleNamespace( + is_ok=True, + row=SimpleNamespace( + attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)] + ), + ) + fail_item = SimpleNamespace(is_ok=False, row=None) + batch_resp = SimpleNamespace(get_result_by_table=lambda _table: [ok_item, fail_item]) + vector._tablestore_client.batch_get_row.return_value = batch_resp + docs = vector.get_by_ids(["id-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + + # text_exists + 
vector._tablestore_client.get_row.return_value = (None, object(), None) + assert vector.text_exists("id-1") is True + vector._tablestore_client.get_row.return_value = (None, None, None) + assert vector.text_exists("id-1") is False + + # delete wrappers + vector._delete_row = MagicMock() + vector.delete_by_ids([]) + vector._delete_row.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._delete_row.call_count == 2 + + vector._search_by_metadata = MagicMock(return_value=["id-a"]) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"] + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-a"]) + + vector._search_by_vector = MagicMock(return_value=["vec-doc"]) + vector._search_by_full_text = MagicMock(return_value=["fts-doc"]) + assert vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"] + assert vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == ["fts-doc"] + + vector._delete_table_if_exist = MagicMock() + vector.delete() + vector._delete_table_if_exist.assert_called_once() + + +def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table_if_not_exist = MagicMock() + vector._create_search_index_if_not_exist = MagicMock() + vector._create_collection(3) + vector._create_table_if_not_exist.assert_not_called() + + 
monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(3) + vector._create_table_if_not_exist.assert_called_once() + vector._create_search_index_if_not_exist.assert_called_once_with(3) + tablestore_module.redis_client.set.assert_called_once() + + vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module)) + vector._tablestore_client.list_table.return_value = ["collection_2"] + assert vector._create_table_if_not_exist() is None + vector._tablestore_client.list_table.return_value = [] + vector._create_table_if_not_exist() + vector._tablestore_client.create_table.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")] + assert vector._create_search_index_if_not_exist(3) is None + vector._tablestore_client.list_search_index.return_value = [] + vector._create_search_index_if_not_exist(3) + vector._tablestore_client.create_search_index.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")] + vector._delete_table_if_exist() + assert vector._tablestore_client.delete_search_index.call_count == 2 + vector._tablestore_client.delete_table.assert_called_once_with("collection_2") + + vector._delete_search_index() + vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx") + + +def test_write_row_and_search_helpers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + vector._write_row( + "id-1", + { + "page_content": "hello", + "vector": [0.1, 0.2], + "metadata": {"doc_id": "d-1", "document_id": "doc-1"}, + }, + ) + put_row_call = vector._tablestore_client.put_row.call_args + assert put_row_call.args[0] == "collection_1" + attrs = put_row_call.args[1].attribute_columns + assert any(item[0] == "metadata_tags" for item in attrs) + + 
vector._delete_row("id-1") + vector._tablestore_client.delete_row.assert_called_once() + + # metadata search pagination + first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next") + second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"") + vector._tablestore_client.search.side_effect = [first_page, second_page] + ids = vector._search_by_metadata("document_id", "doc-1") + assert ids == ["row-1", "row-2"] + vector._tablestore_client.search.side_effect = None + + # vector search + hit1 = SimpleNamespace( + score=0.9, + row=( + None, + [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))], + ), + ) + hit2 = SimpleNamespace( + score=0.2, + row=( + None, + [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))], + ), + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2]) + docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0) + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0 + + # full text search with and without normalized score filter + vector._normalize_full_text_bm25_score = True + hit3 = SimpleNamespace( + score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))]) + ) + hit4 = SimpleNamespace( + score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))]) + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4]) + docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2) + assert len(docs) == 1 + assert "score" in docs[0].metadata + + 
vector._normalize_full_text_bm25_score = False + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3]) + docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0) + assert len(docs) == 1 + assert "score" not in docs[0].metadata + + +def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch): + factory = tablestore_module.TableStoreVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ENDPOINT", "endpoint") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_INSTANCE_NAME", "instance") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE", True) + + with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py 
b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py new file mode 100644 index 0000000000..d8f35a6019 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py @@ -0,0 +1,309 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_tencent_modules(): + tcvdb_text = types.ModuleType("tcvdb_text") + tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder") + tcvectordb = types.ModuleType("tcvectordb") + tcvectordb_model = types.ModuleType("tcvectordb.model") + tcvectordb_document = types.ModuleType("tcvectordb.model.document") + tcvectordb_index = types.ModuleType("tcvectordb.model.index") + tcvectordb_enum = types.ModuleType("tcvectordb.model.enum") + + class _BM25Encoder: + def encode_texts(self, text): + return {"encoded_text": text} + + def encode_queries(self, query): + return {"encoded_query": query} + + @classmethod + def default(cls, _lang): + return cls() + + class VectorDBError(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message + + class RPCVectorDBClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_database_if_not_exists = MagicMock() + self.exists_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[])) + self.create_collection = MagicMock() + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.search = MagicMock(return_value=[]) + self.hybrid_search = MagicMock(return_value=[]) + self.drop_collection = MagicMock() + + class _Document: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class _HNSWSearchParams: + def __init__(self, ef): + self.ef = ef + + class _AnnSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class 
_KeywordSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _WeightedRerank: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Filter: + @staticmethod + def in_(field, values): + return ("in", field, values) + + def __init__(self, condition): + self.condition = condition + + _Filter.In = staticmethod(_Filter.in_) + + class _HNSWParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _FilterIndex: + def __init__(self, *args): + self.args = args + + class _VectorIndex: + def __init__(self, *args): + self.args = args + + class _SparseIndex: + def __init__(self, **kwargs): + self.kwargs = kwargs + + tcvectordb_enum.IndexType = SimpleNamespace( + __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"}, + PRIMARY_KEY="PRIMARY_KEY", + FILTER="FILTER", + SPARSE_INVERTED="SPARSE", + ) + tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP") + tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector") + + tcvectordb_document.Document = _Document + tcvectordb_document.HNSWSearchParams = _HNSWSearchParams + tcvectordb_document.AnnSearch = _AnnSearch + tcvectordb_document.Filter = _Filter + tcvectordb_document.KeywordSearch = _KeywordSearch + tcvectordb_document.WeightedRerank = _WeightedRerank + + tcvectordb_index.HNSWParams = _HNSWParams + tcvectordb_index.FilterIndex = _FilterIndex + tcvectordb_index.VectorIndex = _VectorIndex + tcvectordb_index.SparseIndex = _SparseIndex + + tcvdb_text_encoder.BM25Encoder = _BM25Encoder + + tcvectordb_model.document = tcvectordb_document + tcvectordb_model.enum = tcvectordb_enum + tcvectordb_model.index = tcvectordb_index + + tcvectordb.RPCVectorDBClient = RPCVectorDBClient + tcvectordb.VectorDBException = VectorDBError + + return { + "tcvdb_text": tcvdb_text, + "tcvdb_text.encoder": tcvdb_text_encoder, + "tcvectordb": tcvectordb, + "tcvectordb.model": 
tcvectordb_model, + "tcvectordb.model.document": tcvectordb_document, + "tcvectordb.model.index": tcvectordb_index, + "tcvectordb.model.enum": tcvectordb_enum, + } + + +@pytest.fixture +def tencent_module(monkeypatch): + for name, module in _build_fake_tencent_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.tencent.tencent_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "url": "http://vdb.local", + "api_key": "api-key", + "timeout": 30, + "username": "user", + "database": "db", + "index_type": "HNSW", + "metric_type": "IP", + "shard": 1, + "replicas": 2, + "max_upsert_batch_size": 2, + "enable_hybrid_search": False, + } + values.update(overrides) + return module.TencentConfig.model_validate(values) + + +def test_config_and_init_paths(tencent_module): + config = _config(tencent_module) + assert config.to_tencent_params()["url"] == "http://vdb.local" + + vector = tencent_module.TencentVector("collection_1", config) + assert vector.get_type() == tencent_module.VectorType.TENCENT + assert vector._client.kwargs["key"] == "api-key" + + vector._client.exists_collection.return_value = True + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)] + ) + vector._client_config.enable_hybrid_search = True + vector._load_collection() + assert vector._enable_hybrid_search is True + assert vector._dimension == 768 + + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=512)] + ) + vector._load_collection() + assert vector._enable_hybrid_search is False + + +def test_create_collection_branches(tencent_module, monkeypatch): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module)) + + lock = MagicMock() + lock.__enter__.return_value = None + 
lock.__exit__.return_value = None + monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.exists_collection.return_value = True + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + vector._client.exists_collection.return_value = False + vector._client_config.index_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_collection(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_collection(3) + + vector._client_config.metric_type = "IP" + vector._client.create_collection.side_effect = [ + tencent_module.VectorDBException("fieldType:json unsupported"), + None, + ] + vector._enable_hybrid_search = True + vector._create_collection(3) + assert vector._client.create_collection.call_count == 2 + tencent_module.redis_client.set.assert_called_once() + vector._client.create_collection.side_effect = None + + +def test_create_add_delete_and_search_behaviour(tencent_module): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True)) + vector._create_collection = MagicMock() + docs = [ + Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}), + Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}), + Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + vector.create(docs, embeddings) + 
vector._create_collection.assert_called_once_with(1) + + vector._client.upsert.reset_mock() + vector.add_texts(docs, embeddings) + assert vector._client.upsert.call_count == 2 + first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"] + assert "sparse_vector" in first_docs[0].__dict__ + + vector._client.query.return_value = [{"id": "a"}] + assert vector.text_exists("a") is True + vector._client.query.return_value = [] + assert vector.text_exists("a") is False + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["a", "b", "c"]) + assert vector._client.delete.call_count == 2 + vector.delete_by_metadata_field("document_id", "doc-a") + assert vector._client.delete.call_count >= 3 + + vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]] + vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(vec_docs) == 1 + assert vec_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._enable_hybrid_search = False + assert vector.search_by_full_text("query") == [] + vector._enable_hybrid_search = True + vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]] + fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(fts_docs) == 1 + + # _get_search_res handles old string metadata format + compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.1) + assert len(compat_docs) == 1 + assert compat_docs[0].metadata["score"] == pytest.approx(0.2) + + vector._has_collection = MagicMock(return_value=True) + vector.delete() + vector._client.drop_collection.assert_called_once() + + +def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch): + factory = tencent_module.TencentVectorFactory() + dataset_with_index = 
SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_URL", "http://vdb.local") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_API_KEY", "api-key") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_TIMEOUT", 30) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_USERNAME", "user") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_DATABASE", "db") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_REPLICAS", 2) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH", True) + + with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py new file mode 100644 index 0000000000..369cda39bf --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace + +import pytest + +from 
core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + + +class _DummyVector(BaseVector): + def __init__(self, collection_name: str, existing_ids: set[str] | None = None): + super().__init__(collection_name) + self._existing_ids = existing_ids or set() + + def get_type(self) -> str: + return "dummy" + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete_by_metadata_field(self, key: str, value: str): + return None + + def search_by_vector(self, query_vector: list[float], **kwargs): + return [] + + def search_by_full_text(self, query: str, **kwargs): + return [] + + def delete(self): + return None + + +@pytest.mark.parametrize( + ("base_method", "args"), + [ + (BaseVector.get_type, ()), + (BaseVector.create, ([], [])), + (BaseVector.add_texts, ([], [])), + (BaseVector.text_exists, ("doc-1",)), + (BaseVector.delete_by_ids, ([],)), + (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.search_by_vector, ([0.1],)), + (BaseVector.search_by_full_text, ("query",)), + (BaseVector.delete, ()), + ], +) +def test_base_vector_default_methods_raise_not_implemented(base_method, args): + vector = _DummyVector("collection_1") + + with pytest.raises(NotImplementedError): + base_method(vector, *args) + + +def test_filter_duplicate_texts_removes_existing_docs(): + vector = _DummyVector("collection_1", existing_ids={"dup"}) + docs = [ + SimpleNamespace(page_content="keep-no-meta", metadata=None), + Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}), + Document(page_content="remove-dup", metadata={"doc_id": "dup"}), + 
Document(page_content="keep-unique", metadata={"doc_id": "unique"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + + assert [d.page_content for d in filtered] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"] + + +def test_get_uuids_and_collection_name_property(): + vector = _DummyVector("collection_1") + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + Document(page_content="c", metadata={"document_id": "d-1"}), + Document(page_content="d", metadata={"doc_id": "id-2"}), + ] + + assert vector._get_uuids(docs) == ["id-1", "id-2"] + assert vector.collection_name == "collection_1" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py new file mode 100644 index 0000000000..dd536af759 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -0,0 +1,434 @@ +import base64 +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str): + fake_module = types.ModuleType(module_path) + fake_cls = type(class_name, (), {}) + setattr(fake_module, class_name, fake_cls) + monkeypatch.setitem(sys.modules, module_path, fake_module) + return fake_cls + + +@pytest.fixture +def vector_factory_module(): + import importlib + + import core.rag.datasource.vdb.vector_factory as module + + return importlib.reload(module) + + +def test_gen_index_struct_dict(vector_factory_module): + result = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict( + vector_factory_module.VectorType.WEAVIATE, + "collection_1", + ) + + assert result == { + "type": vector_factory_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "collection_1"}, + } + + +@pytest.mark.parametrize( + 
("vector_type", "module_path", "class_name"), + [ + ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"), + ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"), + ( + "ALIBABACLOUD_MYSQL", + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector", + "AlibabaCloudMySQLVectorFactory", + ), + ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"), + ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"), + ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"), + ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), + ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"), + ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"), + ( + "ELASTICSEARCH", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector", + "ElasticSearchVectorFactory", + ), + ( + "ELASTICSEARCH_JA", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector", + "ElasticSearchJaVectorFactory", + ), + ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"), + ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"), + ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"), + ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"), + ( + "OPENSEARCH", + "core.rag.datasource.vdb.opensearch.opensearch_vector", + "OpenSearchVectorFactory", + ), + ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), + ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"), + ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"), + ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"), + ("UPSTASH", 
"core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"), + ( + "TIDB_ON_QDRANT", + "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector", + "TidbOnQdrantVectorFactory", + ), + ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"), + ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"), + ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"), + ( + "HUAWEI_CLOUD", + "core.rag.datasource.vdb.huawei.huawei_cloud_vector", + "HuaweiCloudVectorFactory", + ), + ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"), + ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), + ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"), + ], +) +def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name): + expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name) + + result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type)) + + assert result_cls is expected_cls + + +def test_get_vector_factory_unsupported(vector_factory_module): + with pytest.raises(ValueError, match="not supported"): + vector_factory_module.Vector.get_vector_factory("unknown") + + +def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): + dataset = SimpleNamespace(id="dataset-1") + + with ( + patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"), + patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"), + ): + default_vector = vector_factory_module.Vector(dataset) + 
custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) + + assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + assert custom_vector._attributes == ["doc_id"] + assert default_vector._embeddings == "embeddings" + assert default_vector._vector_processor == "processor" + + +def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): + calls = {"vector_type": None, "init_args": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + calls["init_args"] = (dataset, attributes, embeddings) + return "vector-processor" + + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1" + ) + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH + assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings") + + +def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch): + class _Expr: + def __eq__(self, _other): + return "expr" + + calls = {"vector_type": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + return "vector-processor" + + monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: 
object()))), + ) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True) + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT + + +def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch): + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", None) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", False) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = [] + vector._embeddings = "embeddings" + + with pytest.raises(ValueError, match="Vector store must be specified"): + vector._init_vector() + + +def test_create_batches_texts_and_skips_empty_input(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)] + vector._embeddings.embed_documents.side_effect = [ + [[0.1] for _ in range(1000)], + [[0.2]], + ] + + vector.create(texts=docs, trace_id="trace-1") + + assert vector._embeddings.embed_documents.call_count == 2 + assert vector._vector_processor.create.call_count == 2 + assert 
vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1" + + vector._embeddings.embed_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create(texts=None) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch): + class _Field: + def in_(self, value): + return value + + def __eq__(self, value): + return value + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]] + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace( + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")]) + ) + ), + ) + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc")) + + docs = [ + Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}), + Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}), + ] + + vector.create_multimodal(file_documents=docs, request_id="r-1") + + file_base64 = base64.b64encode(b"abc").decode() + vector._embeddings.embed_multimodal_documents.assert_called_once_with( + [{"content": file_base64, "content_type": "image", "file_id": "f-1"}] + ) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], + embeddings=[[0.1, 0.2]], + request_id="r-1", + ) + + vector._embeddings.embed_multimodal_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create_multimodal(file_documents=None) + 
vector._embeddings.embed_multimodal_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_add_texts_with_optional_duplicate_check(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + vector._filter_duplicate_texts = MagicMock() + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + vector._filter_duplicate_texts.return_value = [docs[0]] + vector._embeddings.embed_documents.return_value = [[0.1]] + + vector.add_texts(docs, duplicate_check=True, flag=True) + + vector._filter_duplicate_texts.assert_called_once_with(docs) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True + ) + + vector._filter_duplicate_texts.reset_mock() + vector._vector_processor.create.reset_mock() + vector._embeddings.embed_documents.return_value = [[0.2], [0.3]] + + vector.add_texts(docs, duplicate_check=False) + + vector._filter_duplicate_texts.assert_not_called() + vector._vector_processor.create.assert_called_once() + + +def test_vector_delegation_methods(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_query.return_value = [0.1, 0.2] + vector._vector_processor = MagicMock() + vector._vector_processor.text_exists.return_value = True + vector._vector_processor.search_by_vector.return_value = ["vector-doc"] + vector._vector_processor.search_by_full_text.return_value = ["text-doc"] + + assert vector.text_exists("doc-1") is True + vector.delete_by_ids(["doc-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"] + assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"] + + 
vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"]) + vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1") + + +def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): + class _Field: + def __eq__(self, value): + return value + + upload_query = MagicMock() + upload_query.where.return_value = upload_query + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr( + vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query)) + ) + + upload_query.first.return_value = None + assert vector.search_by_file("file-1") == [] + + upload_query.first.return_value = SimpleNamespace(key="blob-key") + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes")) + vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4] + vector._vector_processor.search_by_vector.return_value = ["hit"] + + result = vector.search_by_file("file-2", top_k=2) + + assert result == ["hit"] + payload = vector._embeddings.embed_multimodal_query.call_args.args[0] + assert payload["content_type"] == vector_factory_module.DocType.IMAGE + assert payload["file_id"] == "file-2" + + +def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch): + delete_mock = MagicMock() + redis_delete = MagicMock() + monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1") + + vector.delete() + + delete_mock.assert_called_once() + 
redis_delete.assert_called_once_with("vector_indexing_collection_1") + + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="") + redis_delete.reset_mock() + vector.delete() + redis_delete.assert_not_called() + + +def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch): + model_manager = MagicMock() + model_manager.get_model_instance.return_value = "model-instance" + + monkeypatch.setattr(vector_factory_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding")) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + result = vector._get_embeddings() + + assert result == "cached-embedding" + model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=vector_factory_module.ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + +def test_filter_duplicate_texts_and_getattr(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup") + + docs = [ + SimpleNamespace(page_content="no-meta", metadata=None), + Document(page_content="empty-doc-id", metadata={"doc_id": ""}), + Document(page_content="duplicate", metadata={"doc_id": "dup"}), + Document(page_content="unique", metadata={"doc_id": "ok"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"] + + class _Processor: + def ping(self): + return "pong" + + vector._vector_processor = _Processor() + assert vector.ping() == "pong" + + with pytest.raises(AttributeError): + _ = vector.unknown_method + + 
vector._vector_processor = None + with pytest.raises(AttributeError, match="vector_processor"): + _ = vector.another_missing diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 0000000000..951a920f3b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,443 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +@pytest.fixture +def tidb_module(): + import core.rag.datasource.vdb.tidb_vector.tidb_vector as module + + return importlib.reload(module) + + +def _config(tidb_module): + return tidb_module.TiDBVectorConfig( + host="localhost", + port=4000, + user="root", + password="secret", + database="dify", + program_name="dify-app", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config TIDB_VECTOR_HOST is required"), + ("port", 0, "config TIDB_VECTOR_PORT is required"), + ("user", "", "config TIDB_VECTOR_USER is required"), + ("database", "", "config TIDB_VECTOR_DATABASE is required"), + ("program_name", "", "config APPLICATION_NAME is required"), + ], +) +def test_tidb_config_validation(tidb_module, field, value, message): + values = _config(tidb_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + tidb_module.TiDBVectorConfig.model_validate(values) + + +def test_init_get_type_and_distance_func(tidb_module, monkeypatch): + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine")) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2") + + assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR + assert 
vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify") + assert vector._dimension == 1536 + assert vector._get_distance_func() == "VEC_L2_DISTANCE" + + vector._distance_func = "cosine" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + vector._distance_func = "other" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + +def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch): + fake_tidb_vector = types.ModuleType("tidb_vector") + fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy") + + class _VectorType: + def __init__(self, dim): + self.dim = dim + + fake_tidb_sqlalchemy.VectorType = _VectorType + + monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector) + monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy) + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock())) + monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr( + tidb_module, + "Table", + lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns), + ) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module)) + table = vector._table(3) + + assert table.name == "collection_1" + column_names = [column.args[0] for column in table.columns] + assert tidb_module.Field.PRIMARY_KEY in column_names + assert tidb_module.Field.VECTOR in column_names + assert tidb_module.Field.TEXT_KEY in column_names + + +def test_create_calls_collection_and_add_texts(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + 
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + assert vector._dimension == 2 + + +def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + + tidb_module.Session = MagicMock() + + vector._create_collection(3) + + tidb_module.Session.assert_not_called() + tidb_module.redis_client.set.assert_not_called() + + +def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "l2" + + vector._create_collection(3) + + session.begin.assert_called_once() + sql = str(session.execute.call_args.args[0]) + assert "VECTOR(3)" in sql + assert "VEC_L2_DISTANCE" in sql + session.commit.assert_called_once() + tidb_module.redis_client.set.assert_called_once() + + +def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch): + class 
_InsertStmt: + def __init__(self, table): + self.table = table + + def values(self, rows): + return {"table": self.table, "rows": rows} + + monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table)) + + conn = MagicMock() + transaction = MagicMock() + transaction.__enter__.return_value = None + transaction.__exit__.return_value = None + conn.begin.return_value = transaction + + connection_ctx = MagicMock() + connection_ctx.__enter__.return_value = conn + connection_ctx.__exit__.return_value = None + + engine = MagicMock() + engine.connect.return_value = connection_ctx + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._engine = engine + vector._table = MagicMock(return_value="table") + + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)] + embeddings = [[float(i)] for i in range(501)] + + ids = vector.add_texts(docs, embeddings) + + assert ids[0] == "id-0" + assert len(ids) == 501 + assert conn.execute.call_count == 2 + + +@pytest.fixture +def tidb_vector_with_session(tidb_module, monkeypatch): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + return vector, session, tidb_module + + +# 1. search_by_full_text returns empty +def test_search_by_full_text_returns_empty(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + assert vector.search_by_full_text("query") == [] + + +# 2. 
# (cont.) text_exists returns True when ids found
def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session):
    vec, _, _ = tidb_vector_with_session
    # Stub the metadata lookup so text_exists sees a matching segment id.
    vec.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    assert vec.text_exists("doc-1") is True


# 3. text_exists returns False when no ids
def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session):
    vec, _, _ = tidb_vector_with_session
    # A None lookup result must be reported as "does not exist".
    vec.get_ids_by_metadata_field = MagicMock(return_value=None)
    assert vec.text_exists("doc-1") is False


# 4. delete_by_ids delegates to _delete_by_ids when ids found
def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session):
    vec, sess, module = tidb_vector_with_session
    sess.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)]
    vec._delete_by_ids = MagicMock()
    # Bind the real get_ids_by_metadata_field onto the bare instance so the
    # delegation path under test runs end-to-end.
    vec.get_ids_by_metadata_field = module.TiDBVector.get_ids_by_metadata_field.__get__(
        vec, module.TiDBVector
    )
    vec.delete_by_ids(["doc-a", "doc-b"])
    vec._delete_by_ids.assert_called_once_with(["id-a", "id-b"])


# 5. delete_by_ids skips when no ids found
def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session):
    vec, sess, module = tidb_vector_with_session
    sess.execute.return_value.fetchall.return_value = []
    vec._delete_by_ids = MagicMock()
    # Bind the real get_ids_by_metadata_field so the empty result is genuine.
    vec.get_ids_by_metadata_field = module.TiDBVector.get_ids_by_metadata_field.__get__(
        vec, module.TiDBVector
    )
    vec.delete_by_ids(["doc-c"])
    vec._delete_by_ids.assert_not_called()


# 6.
# (cont.) get_ids_by_metadata_field returns ids and returns None
def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session):
    vec, sess, module = tidb_vector_with_session
    # First call: the query yields one row -> ids are returned.
    sess.execute.return_value.fetchall.return_value = [("id-1",)]
    assert vec.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"]
    # Second call: no rows -> None signals "nothing matched".
    sess.execute.return_value.fetchall.return_value = []
    assert vec.get_ids_by_metadata_field("doc_id", "doc-1") is None


# 1. _delete_by_ids raises on None
def test__delete_by_ids_raises_on_none(tidb_module):
    vec = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    with pytest.raises(ValueError, match="No ids provided"):
        vec._delete_by_ids(None)


# 2. _delete_by_ids returns True and calls execute
def test__delete_by_ids_returns_true_and_calls_execute(tidb_module):
    # Minimal stand-ins for the SQLAlchemy column/delete objects the
    # implementation chains through.
    class _IDColumn:
        def in_(self, ids):
            return ids

    class _Delete:
        def where(self, condition):
            return condition

    fake_table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    connection = MagicMock()
    txn = MagicMock()
    txn.__enter__.return_value = None
    txn.__exit__.return_value = None
    connection.begin.return_value = txn
    connection_ctx = MagicMock()
    connection_ctx.__enter__.return_value = connection
    connection_ctx.__exit__.return_value = None
    vec = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vec._collection_name = "collection_1"
    vec._dimension = 2
    vec._engine = SimpleNamespace(connect=MagicMock(return_value=connection_ctx))
    vec._table = MagicMock(return_value=fake_table)
    assert vec._delete_by_ids(["id-1"]) is True
    connection.execute.assert_called_once()


# 3.
# (cont.) _delete_by_ids returns False on RuntimeError
def test__delete_by_ids_returns_false_on_runtime_error(tidb_module):
    # Minimal stand-ins for the SQLAlchemy column/delete objects the
    # implementation chains through.
    class _IDColumn:
        def in_(self, ids):
            return ids

    class _Delete:
        def where(self, condition):
            return condition

    fake_table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    connection = MagicMock()
    txn = MagicMock()
    txn.__enter__.return_value = None
    txn.__exit__.return_value = None
    connection.begin.return_value = txn
    connection_ctx = MagicMock()
    connection_ctx.__enter__.return_value = connection
    connection_ctx.__exit__.return_value = None
    # Force the DELETE statement to blow up so the failure path is taken.
    connection.execute.side_effect = RuntimeError("delete failed")
    vec = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vec._collection_name = "collection_1"
    vec._dimension = 2
    vec._engine = SimpleNamespace(connect=MagicMock(return_value=connection_ctx))
    vec._table = MagicMock(return_value=fake_table)
    assert vec._delete_by_ids(["id-2"]) is False


# 4. delete_by_metadata_field calls _delete_by_ids when ids found
def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module):
    vec = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vec.get_ids_by_metadata_field = MagicMock(return_value=["id-3"])
    vec._delete_by_ids = MagicMock()
    vec.delete_by_metadata_field("doc_id", "doc-3")
    vec._delete_by_ids.assert_called_once_with(["id-3"])


# 5.
delete_by_metadata_field does nothing when no ids +def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-4") + vector._delete_by_ids.assert_not_called() + + +# Test search_by_vector filters and scores +def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = [ + ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2), + ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4), + ] + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "cosine" + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.8) + assert docs[1].metadata["score"] == pytest.approx(0.6) + sql = str(session.execute.call_args.args[0]) + params = session.execute.call_args.kwargs["params"] + assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql + assert params["distance"] == pytest.approx(0.5) + assert params["top_k"] == 2 + session.commit.assert_not_called() + + +# Test delete drops table +def test_delete_drops_table(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = None + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda 
_engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector.delete() + drop_sql = str(session.execute.call_args.args[0]) + assert "DROP TABLE IF EXISTS collection_1" in drop_sql + session.commit.assert_called_once() + + +def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): + factory = tidb_module.TiDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000) + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify") + monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app") + + with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py 
b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py new file mode 100644 index 0000000000..ac8a63a44b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py @@ -0,0 +1,186 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_upstash_module(): + upstash_module = types.ModuleType("upstash_vector") + + class Vector: + def __init__(self, id, vector, metadata, data): + self.id = id + self.vector = vector + self.metadata = metadata + self.data = data + + class Index: + def __init__(self, url, token): + self.url = url + self.token = token + self.info = MagicMock(return_value=SimpleNamespace(dimension=8)) + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.reset = MagicMock() + + upstash_module.Vector = Vector + upstash_module.Index = Index + return upstash_module + + +@pytest.fixture +def upstash_module(monkeypatch): + # Remove patched modules if present + for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]: + if modname in sys.modules: + monkeypatch.delitem(sys.modules, modname, raising=False) + monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module()) + module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector") + return module + + +def _config(module): + return module.UpstashVectorConfig(url="https://upstash.example", token="token-123") + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("url", "", "Upstash URL is required"), + ("token", "", "Upstash Token is required"), + ], +) +def test_upstash_config_validation(upstash_module, field, value, message): + values = _config(upstash_module).model_dump() + values[field] = value + + with 
pytest.raises(ValidationError, match=message): + upstash_module.UpstashVectorConfig.model_validate(values) + + +def test_init_get_type_and_dimension(upstash_module, monkeypatch): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + assert vector.get_type() == upstash_module.VectorType.UPSTASH + assert vector._table_name == "collection_1" + assert vector._get_index_dimension() == 8 + + vector.index.info.return_value = SimpleNamespace(dimension=None) + assert vector._get_index_dimension() == 1536 + + vector.index.info.return_value = None + assert vector._get_index_dimension() == 1536 + + monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid") + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.add_texts(docs, [[0.1, 0.2]]) + + vector.index.upsert.assert_called_once() + upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"] + assert upsert_vectors[0].id == "generated-uuid" + + +def test_create_text_exists_and_delete_by_ids(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + vector.get_ids_by_metadata_field.return_value = [] + assert vector.text_exists("doc-1") is False + + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]]) + vector._delete_by_ids = MagicMock() + vector.delete_by_ids(["doc-1", "doc-2", "doc-3"]) + vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"]) + + +def test_delete_helpers_and_search(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + vector._delete_by_ids([]) + vector.index.delete.assert_not_called() + 
vector._delete_by_ids(["a", "b"]) + vector.index.delete.assert_called_once_with(ids=["a", "b"]) + + vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")] + ids = vector.get_ids_by_metadata_field("doc_id", "doc-1") + assert ids == ["x-1", "x-2"] + query_kwargs = vector.index.query.call_args.kwargs + assert query_kwargs["top_k"] == 1000 + assert query_kwargs["filter"] == "doc_id = 'doc-1'" + + vector._delete_by_ids = MagicMock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + vector._delete_by_ids.assert_called_once_with(["x-1"]) + + vector._delete_by_ids.reset_mock() + vector.get_ids_by_metadata_field.return_value = [] + vector.delete_by_metadata_field("doc_id", "doc-2") + vector._delete_by_ids.assert_not_called() + + +def test_search_by_vector_filter_threshold_and_delete(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.index.query.return_value = [ + SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9), + SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3), + SimpleNamespace(metadata=None, data="text-3", score=0.99), + SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=3, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + search_kwargs = vector.index.query.call_args.kwargs + assert search_kwargs["top_k"] == 3 + assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')" + + assert vector.search_by_full_text("query") == [] + + vector.delete() + vector.index.reset.assert_called_once() + + +def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch): + factory = upstash_module.UpstashVectorFactory() 
+ dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123") + + with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py new file mode 100644 index 0000000000..9da92af2d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py @@ -0,0 +1,310 @@ +import importlib +import json +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_vikingdb_modules(): + volcengine = types.ModuleType("volcengine") + volcengine.__path__ = [] + viking_db = types.ModuleType("volcengine.viking_db") + + class Data(UserDict): + def __init__(self, payload): + super().__init__(payload) + self.fields = payload + + class 
DistanceType: + L2 = "L2" + + class IndexType: + HNSW = "HNSW" + + class QuantType: + Float = "Float" + + class FieldType: + String = "string" + Text = "text" + Vector = "vector" + + class Field: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Collection: + def __init__(self): + self.upsert_data = MagicMock() + self.fetch_data = MagicMock(return_value=None) + self.delete_data = MagicMock() + + class _Index: + def __init__(self): + self.search = MagicMock(return_value=[]) + self.search_by_vector = MagicMock(return_value=[]) + + class VikingDBService: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_collection = MagicMock() + self.create_index = MagicMock() + self.drop_index = MagicMock() + self.drop_collection = MagicMock() + self._collection = _Collection() + self._index = _Index() + self.get_collection = MagicMock(return_value=self._collection) + self.get_index = MagicMock(return_value=self._index) + + viking_db.Data = Data + viking_db.DistanceType = DistanceType + viking_db.Field = Field + viking_db.FieldType = FieldType + viking_db.IndexType = IndexType + viking_db.QuantType = QuantType + viking_db.VectorIndexParams = VectorIndexParams + viking_db.VikingDBService = VikingDBService + + return {"volcengine": volcengine, "volcengine.viking_db": viking_db} + + +@pytest.fixture +def vikingdb_module(monkeypatch): + for name, module in _build_fake_vikingdb_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VikingDBConfig( + access_key="ak", + secret_key="sk", + host="host", + region="region", + scheme="https", + connection_timeout=10, + socket_timeout=20, + ) + + +def test_init_get_type_and_has_checks(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", 
_config(vikingdb_module)) + + assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB + assert vector._index_name == "collection_1_idx" + + assert vector._has_collection() is True + assert vector._has_index() is True + + vector._client.get_collection.side_effect = RuntimeError("missing") + assert vector._has_collection() is False + vector._client.get_collection.side_effect = None + + vector._client.get_index.side_effect = RuntimeError("missing") + assert vector._has_index() is False + + +def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock()) + + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + vector._client.create_index.assert_not_called() + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None)) + vector._has_collection = MagicMock(return_value=False) + vector._has_index = MagicMock(return_value=False) + vector._create_collection(4) + + vector._client.create_collection.assert_called_once() + vector._client.create_index.assert_called_once() + vikingdb_module.redis_client.set.assert_called_once() + + +def test_create_and_add_texts(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, 
[[0.1, 0.2]]) + + vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module)) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}), + ] + vector.add_texts(docs, [[0.1], [0.2]]) + + vector._client.get_collection.assert_called() + upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0] + assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a" + assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2" + + +def test_text_exists_and_delete_operations(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"}) + assert vector.text_exists("id-1") is True + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace( + fields={"message": "data does not exist"} + ) + assert vector.text_exists("id-1") is False + + vector._client.get_collection.return_value.fetch_data.return_value = None + assert vector.text_exists("id-1") is False + + vector.delete_by_ids(["id-1"]) + vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-2"]) + + +def test_get_ids_and_search_helpers(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_index.return_value.search.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "x") == [] + + vector._client.get_index.return_value.search.return_value = [ + SimpleNamespace(id="a", 
fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}), + SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}), + SimpleNamespace(id="c", fields={}), + ] + assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"] + + empty_docs = vector._get_search_res([], score_threshold=0.1) + assert empty_docs == [] + + results = [ + SimpleNamespace( + id="a", + score=0.3, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}), + }, + ), + SimpleNamespace( + id="b", + score=0.9, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}), + }, + ), + ] + + docs = vector._get_search_res(results, score_threshold=0.2) + assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"] + + vector._client.get_index.return_value.search_by_vector.return_value = results + filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"]) + assert len(filtered_docs) == 1 + assert filtered_docs[0].page_content == "doc-b" + assert vector.search_by_full_text("query") == [] + + +def test_delete_drops_index_and_collection_when_present(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._has_index = MagicMock(return_value=True) + vector._has_collection = MagicMock(return_value=True) + + vector.delete() + + vector._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx") + vector._client.drop_collection.assert_called_once_with("collection_1") + + vector._client.drop_index.reset_mock() + vector._client.drop_collection.reset_mock() + vector._has_index.return_value = False + vector._has_collection.return_value = False + vector.delete() + + vector._client.drop_index.assert_not_called() + vector._client.drop_collection.assert_not_called() 
+ + +def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch): + factory = vikingdb_module.VikingDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls: + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10) + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20) + + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +@pytest.mark.parametrize( + ("field", "message"), + [ + ("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"), + ("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"), + ("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"), + ("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"), + ("VIKINGDB_SCHEME", 
"VIKINGDB_SCHEME should not be None"), + ], +) +def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message): + factory = vikingdb_module.VikingDBVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None + ) + + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, field, None) + + with pytest.raises(ValueError, match=message): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py index 3bd656ba84..69d1833001 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py @@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in: - Full-text search result metadata (search_by_full_text) """ +import datetime +import json import unittest from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest + from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.models.document import Document @@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase): def tearDown(self): weaviate_vector_module._weaviate_client = None + def test_config_requires_endpoint(self): + with 
pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"): + WeaviateConfig(endpoint="") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def _create_weaviate_vector(self, mock_weaviate_module): """Helper to create a WeaviateVector instance with mocked client.""" @@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase): ) return wv, mock_client + def test_shutdown_client_logs_debug_when_close_fails(self): + mock_client = MagicMock() + mock_client.close.side_effect = RuntimeError("close failed") + weaviate_vector_module._weaviate_client = mock_client + + with patch.object(weaviate_vector_module.logger, "debug") as mock_debug: + weaviate_vector_module._shutdown_weaviate_client() + + assert weaviate_vector_module._weaviate_client is None + mock_client.close.assert_called_once() + mock_debug.assert_called_once() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.return_value = True + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.side_effect = [False, True] + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def 
test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + config = WeaviateConfig( + endpoint="https://weaviate.example.com", + grpc_endpoint="grpc.example.com:6000", + api_key="test-key", + batch_size=50, + ) + + client = wv._init_client(config) + + assert client is mock_client + assert mock_connect.call_args.kwargs == { + "http_host": "weaviate.example.com", + "http_port": 443, + "http_secure": True, + "grpc_host": "grpc.example.com", + "grpc_port": 6000, + "grpc_secure": False, + "auth_credentials": "auth-token", + "skip_init_checks": True, + } + mock_api_key.assert_called_once_with("test-key") + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_raises_when_database_not_ready(self, mock_connect): + mock_client = MagicMock() + mock_client.is_ready.return_value = False + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + + with pytest.raises(ConnectionError, match="Vector database is not ready"): + wv._init_client(self.config) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_init(self, mock_weaviate_module): """Test WeaviateVector initialization stores attributes including doc_type.""" @@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase): assert wv._collection_name == self.collection_name assert "doc_type" in wv._attributes + def test_get_type_and_to_index_struct(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + + assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE + assert wv.to_index_struct() == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": self.collection_name}, + } + + def 
test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self): + dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1") + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv.get_collection_name(dataset) == "ExistingCollection_Node" + + def test_get_collection_name_generates_name_from_dataset_id(self): + dataset = SimpleNamespace(index_struct_dict=None, id="ds-2") + wv = WeaviateVector.__new__(WeaviateVector) + + with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"): + assert wv.get_collection_name(dataset) == "Generated_Node" + + def test_create_calls_collection_setup_then_add_texts(self): + doc = Document(page_content="hello", metadata={}) + wv = WeaviateVector.__new__(WeaviateVector) + wv._create_collection = MagicMock() + wv.add_texts = MagicMock() + + wv.create([doc], [[0.1, 0.2]]) + + wv._create_collection.assert_called_once() + wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]]) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase): f"doc_type should be in collection schema properties, got: {property_names}" ) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = 1 + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._ensure_properties = MagicMock() + + wv._create_collection() + + wv._client.collections.exists.assert_not_called() + 
wv._ensure_properties.assert_not_called() + mock_redis.set.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_logs_and_reraises_errors(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock(return_value=False) + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = RuntimeError("create failed") + + with patch.object(weaviate_vector_module.logger, "exception") as mock_exception: + with pytest.raises(RuntimeError, match="create failed"): + wv._create_collection() + + mock_exception.assert_called_once() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" @@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase): added_names = [call.args[0].name for call in add_calls] assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [SimpleNamespace(name="text")] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._ensure_properties() 
+ + add_calls = mock_col.config.add_property.call_args_list + added_names = [call.args[0].name for call in add_calls] + assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties does not add doc_type when it already exists.""" @@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase): # No properties should be added mock_col.config.add_property.assert_not_called() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [] + mock_col.config.get.return_value = mock_cfg + mock_col.config.add_property.side_effect = RuntimeError("cannot add") + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with patch.object(weaviate_vector_module.logger, "warning") as mock_warning: + wv._ensure_properties() + + assert mock_warning.call_count == 4 + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_vector returns doc_type in document metadata. 
@@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = { + "text": "fallback distance result", + "document_id": "doc-1", + "doc_id": "segment-1", + } + mock_obj.metadata = None + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.near_vector.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector( + query_vector=[0.2] * 3, + document_ids_filter=["doc-1"], + top_k=2, + score_threshold=-1, + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.0 + assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_vector(query_vector=[0.1] * 3) == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_full_text 
also returns doc_type in document metadata.""" @@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"} + mock_obj.vector = [0.3, 0.4] + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.bm25.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].vector == [0.3, 0.4] + assert mock_col.query.bm25.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_full_text(query="missing") == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): """Test that add_texts includes doc_type from document metadata in stored properties.""" @@ -310,6 +611,135 @@ class 
TestWeaviateVector(unittest.TestCase): stored_props = call_kwargs.kwargs.get("properties") assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_batch = MagicMock() + mock_batch.__enter__ = MagicMock(return_value=mock_batch) + mock_batch.__exit__ = MagicMock(return_value=False) + mock_col.batch.dynamic.return_value = mock_batch + + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + doc = Document(page_content="text", metadata={"created_at": created_at}) + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with ( + patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]), + patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), + ): + ids = wv.add_texts(documents=[doc], embeddings=[[]]) + + assert ids == ["fallback-uuid"] + call_kwargs = mock_batch.add_object.call_args + assert call_kwargs.kwargs["uuid"] == "fallback-uuid" + assert call_kwargs.kwargs["vector"] is None + assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat() + + def test_is_uuid_handles_invalid_values(self): + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv._is_uuid("123e4567-e89b-12d3-a456-426614174000") is True + assert wv._is_uuid("not-a-uuid") is False + + def test_delete_by_metadata_field_returns_when_collection_is_missing(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + 
wv._client.collections.exists.return_value = False + + wv.delete_by_metadata_field("doc_id", "segment-1") + + wv._client.collections.use.assert_not_called() + + def test_delete_by_metadata_field_deletes_matching_objects(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + + wv.delete_by_metadata_field("doc_id", "segment-1") + + mock_col.data.delete_many.assert_called_once() + + def test_delete_removes_collection_when_present(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + + wv.delete() + + wv._client.collections.delete.assert_called_once_with(self.collection_name) + + def test_text_exists_handles_missing_and_present_documents(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()]) + + assert wv.text_exists("segment-1") is False + assert wv.text_exists("segment-1") is True + + def test_delete_by_ids_handles_missing_collections_and_404s(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None] + + 
with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + wv.delete_by_ids(["ignored"]) + wv.delete_by_ids(["missing-id", "ok-id"]) + + assert mock_col.data.delete_by_id.call_count == 2 + + def test_delete_by_ids_reraises_non_404_errors(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500) + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"): + wv.delete_by_ids(["bad-id"]) + + def test_json_serializable_converts_datetime(self): + wv = WeaviateVector.__new__(WeaviateVector) + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + + assert wv._json_serializable(created_at) == created_at.isoformat() + assert wv._json_serializable("plain") == "plain" + class TestVectorDefaultAttributes(unittest.TestCase): """Tests for Vector class default attributes list.""" @@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase): assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}" +class TestWeaviateVectorFactory(unittest.TestCase): + def test_init_vector_uses_existing_dataset_index_struct(self): + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}}, + index_struct=None, + ) + attributes = ["doc_id"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + 
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + config = mock_vector.call_args.kwargs["config"] + assert mock_vector.call_args.kwargs["collection_name"] == "ExistingCollection_Node" + assert mock_vector.call_args.kwargs["attributes"] == attributes + assert config.endpoint == "http://localhost:8080" + assert config.grpc_endpoint == "localhost:50051" + assert config.api_key == "api-key" + assert config.batch_size == 88 + assert dataset.index_struct is None + + def test_init_vector_generates_collection_and_updates_index_struct(self): + dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + attributes = ["doc_id", "doc_type"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100), + patch.object( + weaviate_vector_module.Dataset, + "gen_collection_name_by_id", + return_value="GeneratedCollection_Node", + ), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + assert mock_vector.call_args.kwargs["collection_name"] == "GeneratedCollection_Node" + 
assert json.loads(dataset.index_struct) == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "GeneratedCollection_Node"}, + } + + if __name__ == "__main__": unittest.main() diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index e5f92fbed5..b6577daac8 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1164,7 +1164,7 @@ class TestConversationStatusCount: conversation.id = str(uuid4()) # Mock the database query to return no messages - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with patch("models.model.db.session.scalars") as mock_scalars: mock_scalars.return_value.all.return_value = [] # Act @@ -1189,7 +1189,7 @@ class TestConversationStatusCount: conversation.id = conversation_id # Mock the database query to return no messages with workflow_run_id - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with patch("models.model.db.session.scalars") as mock_scalars: mock_scalars.return_value.all.return_value = [] # Act @@ -1274,7 +1274,7 @@ class TestConversationStatusCount: return mock_result # Act & Assert - with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): + with patch("models.model.db.session.scalars", side_effect=mock_scalars): result = conversation.status_count # Verify only 2 database queries were made (not N+1) @@ -1337,7 +1337,7 @@ class TestConversationStatusCount: return mock_result # Act - with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): + with patch("models.model.db.session.scalars", side_effect=mock_scalars): result = conversation.status_count # Assert - query should include app_id filter @@ -1382,7 +1382,7 @@ class TestConversationStatusCount: ), ] - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with 
patch("models.model.db.session.scalars") as mock_scalars: # Mock the messages query def mock_scalars_side_effect(query): mock_result = MagicMock() @@ -1438,7 +1438,7 @@ class TestConversationStatusCount: ), ] - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with patch("models.model.db.session.scalars") as mock_scalars: def mock_scalars_side_effect(query): mock_result = MagicMock() diff --git a/api/tests/unit_tests/services/plugin/test_oauth_service.py b/api/tests/unit_tests/services/plugin/test_oauth_service.py index 27df4556bc..6511385000 100644 --- a/api/tests/unit_tests/services/plugin/test_oauth_service.py +++ b/api/tests/unit_tests/services/plugin/test_oauth_service.py @@ -13,6 +13,10 @@ import pytest from services.plugin.oauth_service import OAuthProxyService +def _oauth_proxy_setex_calls(redis_client) -> list: + return [call for call in redis_client.setex.call_args_list if call.args[0].startswith("oauth_proxy_context:")] + + class TestCreateProxyContext: def test_stores_context_in_redis_with_ttl(self): context_id = OAuthProxyService.create_proxy_context( @@ -22,8 +26,9 @@ class TestCreateProxyContext: assert context_id # non-empty UUID string from extensions.ext_redis import redis_client - redis_client.setex.assert_called_once() - call_args = redis_client.setex.call_args + oauth_calls = _oauth_proxy_setex_calls(redis_client) + assert len(oauth_calls) == 1 + call_args = oauth_calls[0] key = call_args[0][0] ttl = call_args[0][1] stored_data = json.loads(call_args[0][2]) diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index 4f7d184046..239e51119c 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -211,6 +211,7 @@ def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch def test_import_app_pending_stores_import_info_in_redis(): service = 
AppDslService(MagicMock()) + app_dsl_service.redis_client.setex.reset_mock() result = service.import_app( account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, @@ -375,10 +376,13 @@ def test_confirm_import_success_deletes_redis_key(monkeypatch): created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + app_dsl_service.redis_client.delete.reset_mock() result = service.confirm_import(import_id="import-1", account=_account_mock()) assert result.status == ImportStatus.COMPLETED assert result.app_id == "confirmed-app" - app_dsl_service.redis_client.delete.assert_called_once() + app_dsl_service.redis_client.delete.assert_called_once_with( + f"{app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX}import-1" + ) def test_confirm_import_exception_returns_failed(monkeypatch): diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d67469105..35b288cf7c 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -405,7 +405,7 @@ class TestAudioServiceTTS: voice="en-US-Neural", ) - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): """Test successful TTS with message ID.""" @@ -549,7 +549,7 @@ class TestAudioServiceTTS: with pytest.raises(ValueError, match="Text is required"): AudioService.transcript_tts(app_model=app, text=None) - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory): """Test that TTS returns None for invalid message ID 
format.""" # Arrange @@ -564,7 +564,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory): """Test that TTS returns None when message doesn't exist.""" # Arrange @@ -585,7 +585,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory): """Test that TTS returns None when message answer is empty.""" # Arrange diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index a23c44b26e..3b1c1fcf17 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -313,7 +313,8 @@ class TestEmailDeliveryTestHandler: recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], ) - subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") + with patch.object(dify_config, "APP_WEB_URL", "http://example.com"): + subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") assert subs["node_title"] == "title" assert subs["form_content"] == "content" diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 4d2d63e501..b09463b1bc 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -316,7 +316,7 @@ class TestTagServiceRetrieval: - get_tags_by_target_id: Get all tags bound to a specific target """ - 
@patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_with_binding_counts(self, mock_db_session, factory): """ Test retrieving tags with their binding counts. @@ -373,7 +373,7 @@ class TestTagServiceRetrieval: # Verify database query was called mock_db_session.query.assert_called_once() - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_with_keyword_filter(self, mock_db_session, factory): """ Test retrieving tags filtered by keyword (case-insensitive). @@ -427,7 +427,7 @@ class TestTagServiceRetrieval: # 2. Additional WHERE clause for keyword filtering assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): """ Test retrieving target IDs by tag IDs. @@ -483,7 +483,7 @@ class TestTagServiceRetrieval: # Verify both queries were executed assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): """ Test that empty tag_ids returns empty list. @@ -511,7 +511,7 @@ class TestTagServiceRetrieval: assert results == [], "Should return empty list for empty input" mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_by_tag_name(self, mock_db_session, factory): """ Test retrieving tags by name. 
@@ -553,7 +553,7 @@ class TestTagServiceRetrieval: assert len(results) == 1, "Should find exactly one tag" assert results[0].name == tag_name, "Tag name should match" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): """ Test that missing tag_type or tag_name returns empty list. @@ -581,7 +581,7 @@ class TestTagServiceRetrieval: # Verify no database queries were executed mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_by_target_id(self, mock_db_session, factory): """ Test retrieving tags associated with a specific target. @@ -654,7 +654,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") @patch("services.tag_service.uuid.uuid4", autospec=True) def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ @@ -743,7 +743,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ Test updating a tag name. 
@@ -795,7 +795,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags_raises_error_for_duplicate_name( self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory ): @@ -827,7 +827,7 @@ class TestTagServiceCRUD: with pytest.raises(ValueError, match="Tag name already exists"): TagService.update_tags(args, tag_id="tag-123") - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): """ Test that updating a non-existent tag raises NotFound. @@ -859,7 +859,7 @@ class TestTagServiceCRUD: with pytest.raises(NotFound, match="Tag not found"): TagService.update_tags(args, tag_id="nonexistent") - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_binding_count(self, mock_db_session, factory): """ Test getting the count of bindings for a tag. @@ -895,7 +895,7 @@ class TestTagServiceCRUD: # Verify count matches expectation assert result == expected_count, "Binding count should match" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag(self, mock_db_session, factory): """ Test deleting a tag and its bindings. @@ -951,7 +951,7 @@ class TestTagServiceCRUD: # Verify transaction was committed mock_db_session.commit.assert_called_once(), "Should commit transaction" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_raises_not_found(self, mock_db_session, factory): """ Test that deleting a non-existent tag raises NotFound. 
@@ -999,7 +999,7 @@ class TestTagServiceBindings: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test creating tag bindings. @@ -1050,7 +1050,7 @@ class TestTagServiceBindings: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test that saving duplicate bindings is idempotent. @@ -1090,7 +1090,7 @@ class TestTagServiceBindings: mock_db_session.add.assert_not_called(), "Should not create duplicate binding" @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): """ Test deleting a tag binding. @@ -1138,7 +1138,7 @@ class TestTagServiceBindings: mock_db_session.commit.assert_called_once(), "Should commit transaction" @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): """ Test that deleting a non-existent binding is a no-op. 
@@ -1175,7 +1175,7 @@ class TestTagServiceBindings: mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): """ Test validating that a dataset target exists. @@ -1216,7 +1216,7 @@ class TestTagServiceBindings: mock_db_session.query.assert_called_once(), "Should query database for dataset" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): """ Test validating that an app target exists. @@ -1257,7 +1257,7 @@ class TestTagServiceBindings: mock_db_session.query.assert_called_once(), "Should query database for app" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_raises_not_found_for_missing_dataset( self, mock_db_session, mock_current_user, factory ): @@ -1289,7 +1289,7 @@ class TestTagServiceBindings: TagService.check_target_exists("knowledge", "nonexistent") @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): """ Test that missing app raises NotFound. 
diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 6804ade5aa..027cd3b1ec 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -59,6 +59,11 @@ def mock_redis(): # Redis is already mocked globally in conftest.py # Reset it for each test redis_client.reset_mock() + redis_client.get.reset_mock() + redis_client.setex.reset_mock() + redis_client.delete.reset_mock() + redis_client.lpush.reset_mock() + redis_client.rpop.reset_mock() redis_client.get.return_value = None redis_client.setex.return_value = True redis_client.delete.return_value = True