mirror of
https://github.com/langgenius/dify.git
synced 2026-03-25 08:18:02 +08:00
test: unit test cases for rag.cleaner, rag.data_post_processor and rag.datasource (#32521)
This commit is contained in:
@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector):
|
||||
)
|
||||
)
|
||||
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score >= score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
@ -211,3 +211,16 @@ class TestCleanProcessor:
|
||||
text = "[Text with (parens) and symbols](https://example.com)"
|
||||
expected = "[Text with (parens) and symbols](https://example.com)"
|
||||
assert CleanProcessor.clean(text, process_rule) == expected
|
||||
|
||||
    def test_clean_remove_urls_emails_preserves_markdown_image_links(self):
        """Remove plain URLs and emails while preserving markdown image links."""
        # NOTE(review): the fixture text ends with "keep " but contains no
        # markdown image link, so the "preserves markdown image links" half of
        # this test's name is never exercised — confirm whether an
        # ![img](https://...) fragment was lost from the fixture.
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
        text = "Email test@example.com and remove https://remove.com but keep "
        result = CleanProcessor.clean(text, process_rule)

        # Both the e-mail address and the bare URL must be stripped.
        assert result == "Email and remove but keep "
|
||||
|
||||
def test_filter_string_returns_input_text(self):
|
||||
"""Test filter_string passthrough behavior."""
|
||||
processor = CleanProcessor()
|
||||
assert processor.filter_string("raw text") == "raw text"
|
||||
|
||||
@ -0,0 +1,249 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
def _doc(content: str) -> Document:
    """Shorthand for a Document carrying only the given page content."""
    document = Document(page_content=content)
    return document
|
||||
|
||||
|
||||
class TestDataPostProcessor:
    """Unit tests for DataPostProcessor: construction, invoke pipeline, and
    the rerank/reorder runner factory helpers (all collaborators mocked)."""

    def test_init_sets_rerank_and_reorder_runners(self):
        # __init__ must delegate to both factory helpers and store the results.
        rerank_runner = object()
        reorder_runner = object()

        with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock:
            with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock:
                processor = DataPostProcessor(
                    tenant_id="tenant-1",
                    reranking_mode=RerankMode.WEIGHTED_SCORE,
                    reranking_model={"config": "value"},
                    weights={"weight": "value"},
                    reorder_enabled=True,
                )

        assert processor.rerank_runner is rerank_runner
        assert processor.reorder_runner is reorder_runner
        rerank_mock.assert_called_once_with(
            RerankMode.WEIGHTED_SCORE,
            "tenant-1",
            {"config": "value"},
            {"weight": "value"},
        )
        reorder_mock.assert_called_once_with(True)

    def test_invoke_applies_rerank_then_reorder(self):
        # invoke must run rerank first, then feed its output to reorder.
        original_documents = [_doc("doc-a")]
        reranked_documents = [_doc("doc-b")]
        reordered_documents = [_doc("doc-c")]

        # __new__ bypasses __init__ so the runners can be injected directly.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        processor.rerank_runner = MagicMock()
        processor.rerank_runner.run.return_value = reranked_documents
        processor.reorder_runner = MagicMock()
        processor.reorder_runner.run.return_value = reordered_documents

        result = processor.invoke(
            query="how to test",
            documents=original_documents,
            score_threshold=0.3,
            top_n=2,
            user="user-1",
            query_type=QueryType.IMAGE_QUERY,
        )

        processor.rerank_runner.run.assert_called_once_with(
            "how to test",
            original_documents,
            0.3,
            2,
            "user-1",
            QueryType.IMAGE_QUERY,
        )
        processor.reorder_runner.run.assert_called_once_with(reranked_documents)
        assert result == reordered_documents

    def test_invoke_returns_original_documents_when_no_runner_is_configured(self):
        # With neither runner configured, invoke is the identity on documents.
        documents = [_doc("doc-a"), _doc("doc-b")]

        processor = DataPostProcessor.__new__(DataPostProcessor)
        processor.rerank_runner = None
        processor.reorder_runner = None

        assert processor.invoke(query="query", documents=documents) == documents

    def test_get_rerank_runner_for_weighted_score(self):
        # WEIGHTED_SCORE mode must build a Weights object from the raw dict
        # and pass it to the runner factory.
        weights_config = {
            "vector_setting": {
                "vector_weight": 0.7,
                "embedding_provider_name": "provider-x",
                "embedding_model_name": "embedding-y",
            },
            "keyword_setting": {"keyword_weight": 0.3},
        }
        expected_runner = object()
        processor = DataPostProcessor.__new__(DataPostProcessor)

        with patch(
            "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner",
            return_value=expected_runner,
        ) as factory_mock:
            result = processor._get_rerank_runner(
                reranking_mode=RerankMode.WEIGHTED_SCORE,
                tenant_id="tenant-1",
                reranking_model=None,
                weights=weights_config,
            )

        assert result is expected_runner
        kwargs = factory_mock.call_args.kwargs
        assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE
        assert kwargs["tenant_id"] == "tenant-1"
        assert kwargs["weights"].vector_setting.vector_weight == 0.7
        assert kwargs["weights"].vector_setting.embedding_provider_name == "provider-x"
        assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y"
        assert kwargs["weights"].keyword_setting.keyword_weight == 0.3

    def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self):
        # If the model instance cannot be resolved, no runner is created.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        reranking_model = {
            "reranking_provider_name": "provider-x",
            "reranking_model_name": "model-y",
        }

        with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock:
            with patch(
                "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner"
            ) as factory_mock:
                result = processor._get_rerank_runner(
                    reranking_mode=RerankMode.RERANKING_MODEL,
                    tenant_id="tenant-1",
                    reranking_model=reranking_model,
                    weights=None,
                )

        assert result is None
        model_mock.assert_called_once_with("tenant-1", reranking_model)
        factory_mock.assert_not_called()

    def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self):
        # A resolvable model instance is forwarded straight to the factory.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        model_instance = object()
        expected_runner = object()

        with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance):
            with patch(
                "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner",
                return_value=expected_runner,
            ) as factory_mock:
                result = processor._get_rerank_runner(
                    reranking_mode=RerankMode.RERANKING_MODEL,
                    tenant_id="tenant-1",
                    reranking_model={
                        "reranking_provider_name": "provider-x",
                        "reranking_model_name": "model-y",
                    },
                    weights=None,
                )

        assert result is expected_runner
        factory_mock.assert_called_once_with(
            runner_type=RerankMode.RERANKING_MODEL,
            rerank_model_instance=model_instance,
        )

    def test_get_rerank_runner_returns_none_for_unsupported_mode(self):
        # Unknown modes (and WEIGHTED_SCORE without weights) yield no runner.
        processor = DataPostProcessor.__new__(DataPostProcessor)

        assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None
        assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None

    def test_get_reorder_runner_by_flag(self):
        # reorder_enabled toggles between a ReorderRunner and None.
        processor = DataPostProcessor.__new__(DataPostProcessor)

        assert isinstance(processor._get_reorder_runner(True), ReorderRunner)
        assert processor._get_reorder_runner(False) is None

    def test_get_rerank_model_instance_returns_none_when_config_is_missing(self):
        processor = DataPostProcessor.__new__(DataPostProcessor)
        assert processor._get_rerank_model_instance("tenant-1", None) is None

    def test_get_rerank_model_instance_raises_key_error_for_incomplete_config(self):
        # A config missing reranking_model_name fails before touching the
        # model manager.
        processor = DataPostProcessor.__new__(DataPostProcessor)

        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            with pytest.raises(KeyError, match="reranking_model_name"):
                processor._get_rerank_model_instance(
                    tenant_id="tenant-1",
                    reranking_model={"reranking_provider_name": "provider-x"},
                )

            manager_instance.get_model_instance.assert_not_called()

    def test_get_rerank_model_instance_success(self):
        # A complete config resolves a RERANK model via the model manager.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        model_instance = object()

        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            manager_instance.get_model_instance.return_value = model_instance

            result = processor._get_rerank_model_instance(
                tenant_id="tenant-1",
                reranking_model={
                    "reranking_provider_name": "provider-x",
                    "reranking_model_name": "reranker-1",
                },
            )

        assert result is model_instance
        manager_instance.get_model_instance.assert_called_once_with(
            tenant_id="tenant-1",
            provider="provider-x",
            model_type=ModelType.RERANK,
            model="reranker-1",
        )

    def test_get_rerank_model_instance_handles_authorization_error(self):
        # Authorization failures are swallowed and reported as "no instance".
        processor = DataPostProcessor.__new__(DataPostProcessor)

        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized")

            result = processor._get_rerank_model_instance(
                tenant_id="tenant-1",
                reranking_model={
                    "reranking_provider_name": "provider-x",
                    "reranking_model_name": "reranker-1",
                },
            )

        assert result is None
|
||||
|
||||
|
||||
class TestReorderRunner:
    """Behavioral tests for ReorderRunner's interleaving reorder."""

    def test_run_reorders_even_sized_document_list(self):
        # Even-length input: evens ascending, then odds descending.
        source = [_doc(str(index)) for index in range(6)]

        shuffled = ReorderRunner().run(source)

        assert [item.page_content for item in shuffled] == ["0", "2", "4", "5", "3", "1"]

    def test_run_handles_odd_sized_and_empty_document_lists(self):
        # Odd-length input follows the same pattern; empty input stays empty.
        source = [_doc(str(index)) for index in range(5)]
        runner = ReorderRunner()

        shuffled = runner.run(source)

        assert [item.page_content for item in shuffled] == ["0", "2", "4", "3", "1"]
        assert runner.run([]) == []
|
||||
@ -0,0 +1,414 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.keyword.jieba.jieba as jieba_module
|
||||
from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class _DummyLock:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
class _Field:
|
||||
def __init__(self, name: str):
|
||||
self._name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return ("eq", self._name, other)
|
||||
|
||||
def in_(self, values):
|
||||
return ("in", self._name, tuple(values))
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self):
|
||||
self.where_calls: list[tuple] = []
|
||||
|
||||
def where(self, *conditions):
|
||||
self.where_calls.append(conditions)
|
||||
return self
|
||||
|
||||
|
||||
class _FakeExecuteResult:
|
||||
def __init__(self, segments: list[SimpleNamespace]):
|
||||
self._segments = segments
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return self._segments
|
||||
|
||||
|
||||
class _FakeSelect:
|
||||
def __init__(self):
|
||||
self.where_conditions: tuple | None = None
|
||||
|
||||
def where(self, *conditions):
|
||||
self.where_conditions = conditions
|
||||
return self
|
||||
|
||||
|
||||
def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None):
|
||||
return SimpleNamespace(
|
||||
data_source_type=data_source_type,
|
||||
keyword_table_dict=keyword_table_dict,
|
||||
keyword_table="",
|
||||
)
|
||||
|
||||
|
||||
def _dataset(dataset_keyword_table=None, keyword_number=None):
|
||||
return SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
keyword_number=keyword_number,
|
||||
dataset_keyword_table=dataset_keyword_table,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def patched_runtime(monkeypatch):
    """Swap the jieba module's db/storage/redis globals for test doubles.

    Returns a namespace exposing ``session``, ``storage`` and ``lock`` so
    tests can assert on calls made through the patched runtime.
    """
    session = MagicMock()
    db = SimpleNamespace(session=session)
    storage = MagicMock()
    # redis_client.lock(...) returns a no-op context manager.
    lock = MagicMock(return_value=_DummyLock())
    redis_client = SimpleNamespace(lock=lock)

    monkeypatch.setattr(jieba_module, "db", db)
    monkeypatch.setattr(jieba_module, "storage", storage)
    monkeypatch.setattr(jieba_module, "redis_client", redis_client)

    return SimpleNamespace(session=session, storage=storage, lock=lock)
|
||||
|
||||
|
||||
def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime):
    """create() indexes docs with a doc_id, skips those without metadata,
    saves the keyword table, and returns self under the Redis lock."""
    dataset = _dataset(_dataset_keyword_table(), keyword_number=2)
    keyword = Jieba(dataset)
    handler = MagicMock()
    handler.extract_keywords.return_value = {"kw1", "kw2"}

    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())

    result = keyword.create(
        [
            Document(page_content="alpha", metadata={"doc_id": "node-1"}),
            # Entry without metadata must be ignored by indexing.
            SimpleNamespace(page_content="ignored", metadata=None),
        ]
    )

    assert result is keyword
    keyword._update_segment_keywords.assert_called_once()
    call_args = keyword._update_segment_keywords.call_args.args
    assert call_args[0] == "dataset-1"
    assert call_args[1] == "node-1"
    assert set(call_args[2]) == {"kw1", "kw2"}
    # Both extracted keywords map back to the single indexed node.
    saved_table = keyword._save_dataset_keyword_table.call_args.args[0]
    assert saved_table["kw1"] == {"node-1"}
    assert saved_table["kw2"] == {"node-1"}
    patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600)
|
||||
|
||||
|
||||
def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime):
    """add_texts() uses a provided keywords list when non-empty and falls
    back to the extractor for entries with an empty list."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"auto"}

    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())

    texts = [
        Document(page_content="extract-this", metadata={"doc_id": "node-1"}),
        Document(page_content="use-manual", metadata={"doc_id": "node-2"}),
    ]
    # First entry has [] -> extractor fallback; second has manual keywords.
    keyword.add_texts(texts, keywords_list=[[], ["manual"]])

    assert keyword._update_segment_keywords.call_count == 2
    first_call = keyword._update_segment_keywords.call_args_list[0].args
    second_call = keyword._update_segment_keywords.call_args_list[1].args
    assert set(first_call[2]) == {"auto"}
    assert second_call[2] == ["manual"]
    keyword._save_dataset_keyword_table.assert_called_once()
|
||||
|
||||
|
||||
def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime):
    """Without keywords_list, add_texts() extracts keywords for every text,
    honoring the dataset's keyword_number limit."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"from-extractor"}

    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())

    keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})])

    # keyword_number=1 is forwarded as the extraction limit.
    handler.extract_keywords.assert_called_once_with("content", 1)
    assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"}
|
||||
|
||||
|
||||
def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch):
    """text_exists() is False with no table, and reflects node membership
    across all keyword buckets when a table exists."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))

    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None))
    assert keyword.text_exists("node-1") is False

    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    assert keyword.text_exists("node-2") is True
    assert keyword.text_exists("node-x") is False
|
||||
|
||||
|
||||
def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime):
    """delete_by_ids() prunes ids from an existing table and persists the
    pruned table."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}}))
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())

    keyword.delete_by_ids(["node-1"])

    keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"])
    keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}})
|
||||
|
||||
|
||||
def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime):
    """With no existing table, delete_by_ids() skips pruning and saves None.

    NOTE(review): saving None here is the observed behavior under test —
    confirm upstream that persisting None (rather than skipping the save)
    is intentional.
    """
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None))
    monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())

    keyword.delete_by_ids(["node-1"])

    keyword._delete_ids_from_keyword_table.assert_not_called()
    keyword._save_dataset_keyword_table.assert_called_once_with(None)
|
||||
|
||||
|
||||
def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime):
    """search() queries segments for ranked node ids, applies the
    document_ids_filter as an extra where clause, and maps matching
    segments into Documents with doc_id/doc_hash metadata."""
    class _FakeDocumentSegment:
        dataset_id = _Field("dataset_id")
        index_node_id = _Field("index_node_id")
        document_id = _Field("document_id")

    keyword = Jieba(_dataset(_dataset_keyword_table()))
    query_stmt = _FakeQuery()
    patched_runtime.session.query.return_value = query_stmt
    # Only node-2 comes back from the (fake) DB, despite two ranked ids.
    patched_runtime.session.execute.return_value = _FakeExecuteResult(
        [
            SimpleNamespace(
                index_node_id="node-2",
                content="segment-content",
                index_node_hash="hash-2",
                document_id="doc-2",
                dataset_id="dataset-1",
            )
        ]
    )

    monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))

    documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])

    # Two where() calls: base dataset/node filter + the document-id filter.
    assert len(query_stmt.where_calls) == 2
    assert len(documents) == 1
    assert documents[0].page_content == "segment-content"
    assert documents[0].metadata["doc_id"] == "node-2"
    assert documents[0].metadata["doc_hash"] == "hash-2"
|
||||
|
||||
|
||||
def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime):
    """delete() always removes the DB row; the storage file is deleted only
    for non-database data sources."""
    db_keyword = _dataset_keyword_table(data_source_type="database")
    file_keyword = _dataset_keyword_table(data_source_type="object_storage")

    keyword_db = Jieba(_dataset(db_keyword))
    keyword_db.delete()
    patched_runtime.storage.delete.assert_not_called()

    keyword_file = Jieba(_dataset(file_keyword))
    keyword_file.delete()

    patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt")
    # One session delete + commit per Jieba.delete() call.
    assert patched_runtime.session.delete.call_count == 2
    assert patched_runtime.session.commit.call_count == 2
|
||||
|
||||
|
||||
def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime):
    """Database-backed tables are serialized (compact JSON, sets handled)
    onto the row and committed."""
    dataset_keyword_table = _dataset_keyword_table(data_source_type="database")
    keyword = Jieba(_dataset(dataset_keyword_table))

    keyword._save_dataset_keyword_table({"kw": {"node-1"}})

    # Serialized payload carries the table type tag and the dataset index id.
    assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table
    assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table
    patched_runtime.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime):
    """Non-database tables are written to object storage, replacing any
    existing file, as bytes under the tenant/dataset keyword path."""
    dataset_keyword_table = _dataset_keyword_table(data_source_type="file")
    keyword = Jieba(_dataset(dataset_keyword_table))
    # Existing file forces the delete-before-save path.
    patched_runtime.storage.exists.return_value = True

    keyword._save_dataset_keyword_table({"kw": {"node-1"}})

    patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt")
    patched_runtime.storage.save.assert_called_once()
    save_args = patched_runtime.storage.save.call_args.args
    assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt"
    assert isinstance(save_args[1], bytes)
|
||||
|
||||
|
||||
def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime):
    """An existing row yields its __data__.table payload; a row with no
    payload yields an empty table."""
    existing = _dataset_keyword_table(
        keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}}
    )
    keyword = Jieba(_dataset(existing))
    assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]}

    missing_payload = _dataset_keyword_table(keyword_table_dict=None)
    keyword_with_missing_payload = Jieba(_dataset(missing_payload))
    assert keyword_with_missing_payload._get_dataset_keyword_table() == {}
|
||||
|
||||
|
||||
def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime):
    """When the dataset has no keyword-table row, one is created with the
    configured data source type, serialized empty, added and committed."""
    created_tables: list[SimpleNamespace] = []

    def _fake_dataset_keyword_table(**kwargs):
        # Capture construction kwargs so the test can inspect the new row.
        kwargs.setdefault("keyword_table", "")
        kwargs.setdefault("keyword_table_dict", None)
        table = SimpleNamespace(**kwargs)
        created_tables.append(table)
        return table

    keyword = Jieba(_dataset(dataset_keyword_table=None))
    monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table)
    monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database")

    result = keyword._get_dataset_keyword_table()

    assert result == {}
    assert len(created_tables) == 1
    assert created_tables[0].dataset_id == "dataset-1"
    assert created_tables[0].data_source_type == "database"
    assert '"index_id":"dataset-1"' in created_tables[0].keyword_table
    patched_runtime.session.add.assert_called_once_with(created_tables[0])
    patched_runtime.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_add_and_delete_ids_from_keyword_table_helpers(monkeypatch, patched_runtime):
    """_add_text_to_keyword_table extends buckets / creates new ones;
    _delete_ids_from_keyword_table drops emptied buckets entirely."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}}

    updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"])
    assert updated["kw1"] == {"node-1", "node-3"}
    assert updated["kw3"] == {"node-3"}

    deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"])
    # Buckets whose last node was removed disappear from the table.
    assert "kw3" not in deleted
    assert "kw1" not in deleted
    assert deleted["kw2"] == {"node-2"}
|
||||
|
||||
|
||||
def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch):
    """Nodes matching more of the query's keywords rank higher; the result
    is truncated to k."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    handler = MagicMock()
    handler.extract_keywords.return_value = ["kw-a", "kw-b"]
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)

    ranked_ids = keyword._retrieve_ids_by_query(
        # node-2 matches both kw-a and kw-b; node-1 only kw-a; kw-c unmatched.
        {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}},
        "query",
        k=1,
    )

    assert ranked_ids == ["node-2"]
|
||||
|
||||
|
||||
def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime):
    """_update_segment_keywords persists keywords onto a found segment and
    is a no-op (no add/commit) when the segment is missing."""
    class _FakeDocumentSegment:
        dataset_id = _Field("dataset_id")
        index_node_id = _Field("index_node_id")

    monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
    monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())

    keyword = Jieba(_dataset(_dataset_keyword_table()))
    segment = SimpleNamespace(keywords=[])
    patched_runtime.session.scalar.return_value = segment

    keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"])

    assert segment.keywords == ["kw1", "kw2"]
    patched_runtime.session.add.assert_called_once_with(segment)
    patched_runtime.session.commit.assert_called_once()

    # Second scenario: segment lookup returns nothing.
    patched_runtime.session.reset_mock()
    patched_runtime.session.scalar.return_value = None

    keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"])

    patched_runtime.session.add.assert_not_called()
    patched_runtime.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch):
    """Both single-segment entry points update the segment (create path)
    and persist the keyword table."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())

    keyword.create_segment_keywords("node-1", ["kw"])
    keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"])
    keyword._save_dataset_keyword_table.assert_called_once()

    keyword._save_dataset_keyword_table.reset_mock()
    keyword.update_segment_keywords_index("node-2", ["kw2"])
    keyword._save_dataset_keyword_table.assert_called_once()
|
||||
|
||||
|
||||
def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch):
    """multi_create_segment_keywords keeps explicit keywords and extracts
    for segments that pass an empty list, writing both into the table."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"auto"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())

    first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None)
    second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None)

    keyword.multi_create_segment_keywords(
        [
            {"segment": first_segment, "keywords": ["manual"]},
            {"segment": second_segment, "keywords": []},
        ]
    )

    assert first_segment.keywords == ["manual"]
    assert second_segment.keywords == ["auto"]
    saved_table = keyword._save_dataset_keyword_table.call_args.args[0]
    assert saved_table["manual"] == {"node-1"}
    assert saved_table["auto"] == {"node-2"}
|
||||
|
||||
|
||||
def test_set_orjson_default_and_dumps_with_sets(monkeypatch):
    """set_orjson_default converts sets (and only sets) for serialization;
    dumps_with_sets round-trips sets through JSON as lists."""
    assert set(set_orjson_default({"a", "b"})) == {"a", "b"}

    # Non-set inputs are rejected rather than silently stringified.
    with pytest.raises(TypeError, match="is not JSON serializable"):
        set_orjson_default(("not", "a", "set"))

    payload = {"items": {"a", "b"}}
    json_payload = dumps_with_sets(payload)
    decoded = json.loads(json_payload)
    assert set(decoded["items"]) == {"a", "b"}
|
||||
@ -0,0 +1,142 @@
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
|
||||
|
||||
|
||||
class _DummyTFIDF:
|
||||
def __init__(self):
|
||||
self.stop_words = set()
|
||||
|
||||
@staticmethod
|
||||
def extract_tags(sentence: str, top_k: int | None = 20, **kwargs):
|
||||
return ["alpha_beta", "during", "gamma"]
|
||||
|
||||
|
||||
def _install_fake_jieba_modules(
|
||||
monkeypatch,
|
||||
analyse_module: types.ModuleType,
|
||||
jieba_attrs: dict[str, object] | None = None,
|
||||
tfidf_module: types.ModuleType | None = None,
|
||||
):
|
||||
jieba_module = types.ModuleType("jieba")
|
||||
jieba_module.__path__ = []
|
||||
if jieba_attrs:
|
||||
for key, value in jieba_attrs.items():
|
||||
setattr(jieba_module, key, value)
|
||||
|
||||
jieba_module.analyse = analyse_module
|
||||
analyse_module.__package__ = "jieba"
|
||||
|
||||
monkeypatch.setitem(sys.modules, "jieba", jieba_module)
|
||||
monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module)
|
||||
if tfidf_module is not None:
|
||||
monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module)
|
||||
else:
|
||||
monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False)
|
||||
|
||||
|
||||
def test_init_uses_existing_default_tfidf(monkeypatch):
    """When jieba.analyse already exposes default_tfidf, the handler reuses
    it and installs the project's STOPWORDS on it."""
    analyse_module = types.ModuleType("jieba.analyse")
    default_tfidf = _DummyTFIDF()
    analyse_module.default_tfidf = default_tfidf

    _install_fake_jieba_modules(monkeypatch, analyse_module)

    handler = JiebaKeywordTableHandler()

    assert handler._tfidf is default_tfidf
    assert handler._tfidf.stop_words == STOPWORDS
|
||||
|
||||
|
||||
def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch):
    """With no default_tfidf but a TFIDF class on jieba.analyse, the handler
    instantiates the class and caches the instance as default_tfidf."""
    analyse_module = types.ModuleType("jieba.analyse")
    analyse_module.default_tfidf = None

    class _TFIDFFactory(_DummyTFIDF):
        pass

    analyse_module.TFIDF = _TFIDFFactory
    _install_fake_jieba_modules(monkeypatch, analyse_module)

    handler = JiebaKeywordTableHandler()

    assert isinstance(handler._tfidf, _TFIDFFactory)
    assert analyse_module.default_tfidf is handler._tfidf
|
||||
|
||||
|
||||
def test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch):
|
||||
analyse_module = types.ModuleType("jieba.analyse")
|
||||
analyse_module.default_tfidf = None
|
||||
|
||||
tfidf_module = types.ModuleType("jieba.analyse.tfidf")
|
||||
|
||||
class _ImportedTFIDF(_DummyTFIDF):
|
||||
pass
|
||||
|
||||
tfidf_module.TFIDF = _ImportedTFIDF
|
||||
_install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module)
|
||||
|
||||
handler = JiebaKeywordTableHandler()
|
||||
|
||||
assert isinstance(handler._tfidf, _ImportedTFIDF)
|
||||
assert analyse_module.default_tfidf is handler._tfidf
|
||||
|
||||
|
||||
def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch):
|
||||
analyse_module = types.ModuleType("jieba.analyse")
|
||||
analyse_module.default_tfidf = None
|
||||
_install_fake_jieba_modules(monkeypatch, analyse_module)
|
||||
|
||||
handler = JiebaKeywordTableHandler()
|
||||
fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1)
|
||||
|
||||
assert fallback_keywords == ["two"]
|
||||
|
||||
|
||||
def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch):
|
||||
analyse_module = types.ModuleType("jieba.analyse")
|
||||
_install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]})
|
||||
|
||||
tfidf = JiebaKeywordTableHandler._build_fallback_tfidf()
|
||||
|
||||
assert tfidf.extract_tags("ignored", topK=1) == ["x"]
|
||||
|
||||
|
||||
def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch):
|
||||
analyse_module = types.ModuleType("jieba.analyse")
|
||||
_install_fake_jieba_modules(
|
||||
monkeypatch,
|
||||
analyse_module,
|
||||
jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])},
|
||||
)
|
||||
|
||||
tfidf = JiebaKeywordTableHandler._build_fallback_tfidf()
|
||||
|
||||
assert tfidf.extract_tags("ignored", topK=1) == ["foo"]
|
||||
|
||||
|
||||
def test_extract_keywords_expands_subtokens():
|
||||
handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler)
|
||||
handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"])
|
||||
|
||||
keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3)
|
||||
|
||||
assert "alpha-beta" in keywords
|
||||
assert "alpha" in keywords
|
||||
assert "beta" in keywords
|
||||
assert "during" in keywords
|
||||
assert "gamma" in keywords
|
||||
|
||||
|
||||
def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens():
|
||||
handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler)
|
||||
|
||||
expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"})
|
||||
|
||||
assert "alpha-during-beta" in expanded
|
||||
assert "alpha" in expanded
|
||||
assert "beta" in expanded
|
||||
assert "during" not in expanded
|
||||
@ -0,0 +1,6 @@
|
||||
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
|
||||
|
||||
|
||||
def test_stopwords_loaded():
|
||||
assert "during" in STOPWORDS
|
||||
assert "the" in STOPWORDS
|
||||
@ -0,0 +1,97 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class _KeywordThatRaises(BaseKeyword):
|
||||
def create(self, texts: list[Document], **kwargs):
|
||||
return super().create(texts, **kwargs)
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
return super().add_texts(texts, **kwargs)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return super().text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
return super().delete_by_ids(ids)
|
||||
|
||||
def delete(self):
|
||||
return super().delete()
|
||||
|
||||
def search(self, query: str, **kwargs):
|
||||
return super().search(query, **kwargs)
|
||||
|
||||
|
||||
class _KeywordForHelpers(BaseKeyword):
|
||||
def __init__(self, dataset, existing_ids: set[str] | None = None):
|
||||
super().__init__(dataset)
|
||||
self._existing_ids = existing_ids or set()
|
||||
|
||||
def create(self, texts: list[Document], **kwargs):
|
||||
return self
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
return None
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return id in self._existing_ids
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
return None
|
||||
|
||||
def delete(self):
|
||||
return None
|
||||
|
||||
def search(self, query: str, **kwargs):
|
||||
return []
|
||||
|
||||
|
||||
def test_abstract_methods_raise_not_implemented():
|
||||
keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1"))
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
keyword.create([])
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
keyword.add_texts([])
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
keyword.text_exists("doc-1")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
keyword.delete_by_ids(["doc-1"])
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
keyword.delete()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
keyword.search("query")
|
||||
|
||||
|
||||
def test_filter_duplicate_texts_removes_existing_doc_ids():
|
||||
keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"})
|
||||
texts = [
|
||||
Document(page_content="keep", metadata={"doc_id": "keep"}),
|
||||
Document(page_content="duplicate", metadata={"doc_id": "duplicate"}),
|
||||
SimpleNamespace(page_content="without-metadata", metadata=None),
|
||||
]
|
||||
|
||||
filtered = keyword._filter_duplicate_texts(texts)
|
||||
|
||||
assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"]
|
||||
assert any(text.metadata is None for text in filtered)
|
||||
|
||||
|
||||
def test_get_uuids_returns_only_docs_with_metadata():
|
||||
keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"))
|
||||
texts = [
|
||||
Document(page_content="doc-1", metadata={"doc_id": "doc-1"}),
|
||||
Document(page_content="doc-2", metadata={"doc_id": "doc-2"}),
|
||||
SimpleNamespace(page_content="doc-3", metadata=None),
|
||||
]
|
||||
|
||||
assert keyword._get_uuids(texts) == ["doc-1", "doc-2"]
|
||||
@ -0,0 +1,84 @@
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.keyword.keyword_type import KeyWordType
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_get_keyword_factory_returns_jieba_factory(monkeypatch):
|
||||
fake_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba")
|
||||
|
||||
class FakeJieba:
|
||||
pass
|
||||
|
||||
fake_module.Jieba = FakeJieba
|
||||
monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", fake_module)
|
||||
|
||||
assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba
|
||||
|
||||
|
||||
def test_get_keyword_factory_raises_for_unsupported_type():
|
||||
with pytest.raises(ValueError, match="Keyword store unsupported is not supported"):
|
||||
Keyword.get_keyword_factory("unsupported")
|
||||
|
||||
|
||||
def test_keyword_initialization_uses_configured_factory(monkeypatch):
|
||||
dataset = SimpleNamespace(id="dataset-1")
|
||||
fake_processor = MagicMock()
|
||||
|
||||
monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA)
|
||||
monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor))
|
||||
|
||||
keyword = Keyword(dataset)
|
||||
|
||||
assert keyword._keyword_processor is fake_processor
|
||||
|
||||
|
||||
def test_keyword_methods_forward_to_processor():
|
||||
processor = MagicMock()
|
||||
processor.text_exists.return_value = True
|
||||
processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})]
|
||||
|
||||
keyword = Keyword.__new__(Keyword)
|
||||
keyword._keyword_processor = processor
|
||||
|
||||
docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})]
|
||||
keyword.create(docs, foo="bar")
|
||||
keyword.add_texts(docs, batch=True)
|
||||
assert keyword.text_exists("doc-1") is True
|
||||
keyword.delete_by_ids(["doc-1"])
|
||||
keyword.delete()
|
||||
assert keyword.search("query", top_k=1) == processor.search.return_value
|
||||
|
||||
processor.create.assert_called_once_with(docs, foo="bar")
|
||||
processor.add_texts.assert_called_once_with(docs, batch=True)
|
||||
processor.text_exists.assert_called_once_with("doc-1")
|
||||
processor.delete_by_ids.assert_called_once_with(["doc-1"])
|
||||
processor.delete.assert_called_once()
|
||||
processor.search.assert_called_once_with("query", top_k=1)
|
||||
|
||||
|
||||
def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes():
|
||||
class Processor:
|
||||
value = 1
|
||||
|
||||
@staticmethod
|
||||
def custom():
|
||||
return "ok"
|
||||
|
||||
keyword = Keyword.__new__(Keyword)
|
||||
keyword._keyword_processor = Processor()
|
||||
|
||||
assert keyword.custom() == "ok"
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_ = keyword.value
|
||||
|
||||
keyword._keyword_processor = None
|
||||
with pytest.raises(AttributeError):
|
||||
_ = keyword.missing_method
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,74 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
|
||||
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
|
||||
|
||||
|
||||
def test_validate_distance_function_accepts_supported_values():
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
|
||||
assert factory._validate_distance_function("cosine") == "cosine"
|
||||
assert factory._validate_distance_function("euclidean") == "euclidean"
|
||||
|
||||
|
||||
def test_validate_distance_function_rejects_unsupported_values():
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid distance function"):
|
||||
factory._validate_distance_function("dot_product")
|
||||
|
||||
|
||||
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch):
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}},
|
||||
index_struct=None,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306)
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5)
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "cosine")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 6)
|
||||
|
||||
with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls:
|
||||
result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
|
||||
|
||||
assert result == "vector"
|
||||
assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
|
||||
|
||||
|
||||
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch):
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-2",
|
||||
index_struct_dict=None,
|
||||
index_struct=None,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306)
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5)
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "euclidean")
|
||||
monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 12)
|
||||
|
||||
with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls:
|
||||
result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
|
||||
|
||||
assert result == "vector"
|
||||
vector_cls.assert_called_once()
|
||||
assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2"
|
||||
assert dataset.index_struct is not None
|
||||
@ -0,0 +1,133 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_init_prefers_openapi_when_api_config_is_provided():
|
||||
api_config = AnalyticdbVectorOpenAPIConfig(
|
||||
access_key_id="ak",
|
||||
access_key_secret="sk",
|
||||
region_id="cn-hangzhou",
|
||||
instance_id="instance-1",
|
||||
account="account",
|
||||
account_password="password",
|
||||
namespace="dify",
|
||||
namespace_password="ns-password",
|
||||
)
|
||||
|
||||
with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls:
|
||||
vector = AnalyticdbVector("COLLECTION", api_config=api_config, sql_config=None)
|
||||
|
||||
assert vector.analyticdb_vector == "openapi_runner"
|
||||
openapi_cls.assert_called_once_with("COLLECTION", api_config)
|
||||
|
||||
|
||||
def test_init_uses_sql_implementation_when_api_config_is_missing():
|
||||
sql_config = AnalyticdbVectorBySqlConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
account="account",
|
||||
account_password="password",
|
||||
min_connection=1,
|
||||
max_connection=2,
|
||||
namespace="dify",
|
||||
)
|
||||
|
||||
with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls:
|
||||
vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config)
|
||||
|
||||
assert vector.analyticdb_vector == "sql_runner"
|
||||
sql_cls.assert_called_once_with("COLLECTION", sql_config)
|
||||
|
||||
|
||||
def test_init_raises_when_both_configs_are_missing():
|
||||
with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"):
|
||||
AnalyticdbVector("COLLECTION", api_config=None, sql_config=None)
|
||||
|
||||
|
||||
def test_vector_methods_delegate_to_underlying_implementation():
|
||||
runner = MagicMock()
|
||||
runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})]
|
||||
runner.search_by_full_text.return_value = [Document(page_content="t", metadata={"doc_id": "2"})]
|
||||
runner.text_exists.return_value = True
|
||||
|
||||
vector = AnalyticdbVector.__new__(AnalyticdbVector)
|
||||
vector.analyticdb_vector = runner
|
||||
|
||||
texts = [Document(page_content="hello", metadata={"doc_id": "d1"})]
|
||||
vector.create(texts=texts, embeddings=[[0.1, 0.2]])
|
||||
vector.add_texts(documents=texts, embeddings=[[0.1, 0.2]])
|
||||
assert vector.text_exists("d1") is True
|
||||
vector.delete_by_ids(["d1"])
|
||||
vector.delete_by_metadata_field("document_id", "doc-1")
|
||||
assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value
|
||||
assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value
|
||||
vector.delete()
|
||||
|
||||
runner._create_collection_if_not_exists.assert_called_once_with(2)
|
||||
runner.add_texts.assert_any_call(texts, [[0.1, 0.2]])
|
||||
runner.delete_by_ids.assert_called_once_with(["d1"])
|
||||
runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1")
|
||||
runner.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_get_type_is_analyticdb():
|
||||
vector = AnalyticdbVector.__new__(AnalyticdbVector)
|
||||
assert vector.get_type() == "analyticdb"
|
||||
|
||||
|
||||
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
|
||||
factory = AnalyticdbVectorFactory()
|
||||
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
|
||||
|
||||
monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", None)
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_ID", "ak")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_SECRET", "sk")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_REGION_ID", "cn-hz")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_INSTANCE_ID", "instance")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE_PASSWORD", "ns-password")
|
||||
|
||||
with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls:
|
||||
result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
|
||||
|
||||
assert result == "vector"
|
||||
args = vector_cls.call_args.args
|
||||
assert args[0] == "auto_collection"
|
||||
assert isinstance(args[1], AnalyticdbVectorOpenAPIConfig)
|
||||
assert args[2] is None
|
||||
assert dataset.index_struct is not None
|
||||
|
||||
|
||||
def test_factory_builds_sql_config_when_host_is_present(monkeypatch):
|
||||
factory = AnalyticdbVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
|
||||
)
|
||||
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", "127.0.0.1")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PORT", 5432)
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password")
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MIN_CONNECTION", 1)
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MAX_CONNECTION", 3)
|
||||
monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify")
|
||||
|
||||
with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls:
|
||||
result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
|
||||
|
||||
assert result == "vector"
|
||||
args = vector_cls.call_args.args
|
||||
assert args[0] == "existing"
|
||||
assert args[1] is None
|
||||
assert isinstance(args[2], AnalyticdbVectorBySqlConfig)
|
||||
@ -0,0 +1,384 @@
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
|
||||
AnalyticdbVectorOpenAPI,
|
||||
AnalyticdbVectorOpenAPIConfig,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _request_class(name: str):
|
||||
class _Request:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
_Request.__name__ = name
|
||||
return _Request
|
||||
|
||||
|
||||
def _install_openapi_stubs(monkeypatch):
|
||||
gpdb_package = types.ModuleType("alibabacloud_gpdb20160503")
|
||||
gpdb_package.__path__ = []
|
||||
gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models")
|
||||
for class_name in [
|
||||
"InitVectorDatabaseRequest",
|
||||
"DescribeNamespaceRequest",
|
||||
"CreateNamespaceRequest",
|
||||
"DescribeCollectionRequest",
|
||||
"CreateCollectionRequest",
|
||||
"UpsertCollectionDataRequestRows",
|
||||
"UpsertCollectionDataRequest",
|
||||
"QueryCollectionDataRequest",
|
||||
"DeleteCollectionDataRequest",
|
||||
"DeleteCollectionRequest",
|
||||
]:
|
||||
setattr(gpdb_models, class_name, _request_class(class_name))
|
||||
|
||||
class _Client:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client")
|
||||
gpdb_client.Client = _Client
|
||||
gpdb_package.models = gpdb_models
|
||||
|
||||
tea_openapi = types.ModuleType("alibabacloud_tea_openapi")
|
||||
tea_openapi.__path__ = []
|
||||
tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models")
|
||||
|
||||
class OpenApiConfig:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
tea_openapi_models.Config = OpenApiConfig
|
||||
tea_openapi.models = tea_openapi_models
|
||||
|
||||
tea_package = types.ModuleType("Tea")
|
||||
tea_package.__path__ = []
|
||||
tea_exceptions = types.ModuleType("Tea.exceptions")
|
||||
|
||||
class TeaError(Exception):
|
||||
def __init__(self, status_code=None, **kwargs):
|
||||
super().__init__("TeaException")
|
||||
status_code = kwargs.get("statusCode", status_code)
|
||||
self.statusCode = status_code
|
||||
self.status_code = status_code
|
||||
|
||||
tea_exceptions.TeaException = TeaError
|
||||
tea_package.exceptions = tea_exceptions
|
||||
|
||||
monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package)
|
||||
monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models)
|
||||
monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client)
|
||||
monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi)
|
||||
monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models)
|
||||
monkeypatch.setitem(sys.modules, "Tea", tea_package)
|
||||
monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions)
|
||||
|
||||
return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig)
|
||||
|
||||
|
||||
def _config() -> AnalyticdbVectorOpenAPIConfig:
|
||||
return AnalyticdbVectorOpenAPIConfig(
|
||||
access_key_id="ak",
|
||||
access_key_secret="sk",
|
||||
region_id="cn-hangzhou",
|
||||
instance_id="instance-1",
|
||||
account="account",
|
||||
account_password="password",
|
||||
namespace="dify",
|
||||
namespace_password="ns-password",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field", "value", "error_message"),
|
||||
[
|
||||
("access_key_id", "", "ANALYTICDB_KEY_ID"),
|
||||
("access_key_secret", "", "ANALYTICDB_KEY_SECRET"),
|
||||
("region_id", "", "ANALYTICDB_REGION_ID"),
|
||||
("instance_id", "", "ANALYTICDB_INSTANCE_ID"),
|
||||
("account", "", "ANALYTICDB_ACCOUNT"),
|
||||
("account_password", "", "ANALYTICDB_PASSWORD"),
|
||||
("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"),
|
||||
],
|
||||
)
|
||||
def test_openapi_config_validation(field, value, error_message):
|
||||
values = _config().model_dump()
|
||||
values[field] = value
|
||||
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
AnalyticdbVectorOpenAPIConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_openapi_config_to_client_params():
|
||||
config = _config()
|
||||
params = config.to_analyticdb_client_params()
|
||||
|
||||
assert params["access_key_id"] == "ak"
|
||||
assert params["access_key_secret"] == "sk"
|
||||
assert params["region_id"] == "cn-hangzhou"
|
||||
assert params["read_timeout"] == 60000
|
||||
|
||||
|
||||
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
initialize_mock = MagicMock()
|
||||
monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock)
|
||||
|
||||
vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config())
|
||||
|
||||
assert vector._collection_name == "collection_1"
|
||||
assert isinstance(vector._client_config, stubs.OpenApiConfig)
|
||||
assert vector._client_config.user_agent == "dify"
|
||||
assert vector._client_config.access_key_id == "ak"
|
||||
assert vector._client.config is vector._client_config
|
||||
initialize_mock.assert_called_once_with()
|
||||
|
||||
|
||||
def test_initialize_skips_when_cached(monkeypatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
|
||||
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
vector._initialize_vector_database = MagicMock()
|
||||
vector._create_namespace_if_not_exists = MagicMock()
|
||||
|
||||
vector._initialize()
|
||||
|
||||
vector._initialize_vector_database.assert_not_called()
|
||||
vector._create_namespace_if_not_exists.assert_not_called()
|
||||
|
||||
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
|
||||
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
vector._initialize_vector_database = MagicMock()
|
||||
vector._create_namespace_if_not_exists = MagicMock()
|
||||
|
||||
vector._initialize()
|
||||
|
||||
vector._initialize_vector_database.assert_called_once()
|
||||
vector._create_namespace_if_not_exists.assert_called_once()
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_initialize_vector_database_calls_openapi_client(monkeypatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
|
||||
vector._initialize_vector_database()
|
||||
|
||||
request = vector._client.init_vector_database.call_args.args[0]
|
||||
assert request.dbinstance_id == "instance-1"
|
||||
assert request.region_id == "cn-hangzhou"
|
||||
assert request.manager_account == "account"
|
||||
assert request.manager_account_password == "password"
|
||||
|
||||
|
||||
def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404)
|
||||
|
||||
vector._create_namespace_if_not_exists()
|
||||
|
||||
vector._client.create_namespace.assert_called_once()
|
||||
|
||||
|
||||
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500)
|
||||
|
||||
with pytest.raises(ValueError, match="failed to create namespace"):
|
||||
vector._create_namespace_if_not_exists()
|
||||
|
||||
|
||||
def test_create_namespace_noop_when_namespace_exists(monkeypatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
|
||||
vector._create_namespace_if_not_exists()
|
||||
|
||||
vector._client.describe_namespace.assert_called_once()
|
||||
vector._client.create_namespace.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
|
||||
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404)
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=1024)
|
||||
|
||||
vector._client.create_collection.assert_called_once()
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1))
|
||||
monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
|
||||
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=1024)
|
||||
|
||||
vector._client.describe_collection.assert_not_called()
|
||||
vector._client.create_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
    """Non-404 describe errors surface as a ValueError instead of triggering creation."""
    stubs = _install_openapi_stubs(monkeypatch)
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())

    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    # 500 is not "missing collection", so creation must fail loudly.
    vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500)

    with pytest.raises(ValueError, match="failed to create collection collection_1"):
        vector._create_collection_if_not_exists(embedding_dimension=512)
|
||||
|
||||
|
||||
def test_openapi_add_delete_and_search_methods(monkeypatch):
    """End-to-end exercise of add/exists/delete/search against a mocked OpenAPI client."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()

    # add_texts: only the document carrying metadata should be upserted.
    documents = [
        Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}),
        SimpleNamespace(page_content="doc 2", metadata=None),
    ]
    embeddings = [[0.1, 0.2], [0.2, 0.3]]
    vector.add_texts(documents, embeddings)

    upsert_request = vector._client.upsert_collection_data.call_args.args[0]
    assert upsert_request.collection == "collection_1"
    assert len(upsert_request.rows) == 1

    # text_exists: any returned match means the id is present.
    vector._client.query_collection_data.return_value = SimpleNamespace(
        body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()]))
    )
    assert vector.text_exists("d1") is True

    # delete_by_ids builds an IN filter over ref_doc_id.
    vector.delete_by_ids(["d1", "d2"])
    request = vector._client.delete_collection_data.call_args.args[0]
    assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')"

    # delete_by_metadata_field targets a JSON field inside metadata_.
    vector.delete_by_metadata_field("document_id", "doc-1")
    request = vector._client.delete_collection_data.call_args.args[0]
    assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'"

    # Search: only the match above the score threshold should survive filtering.
    match_high = SimpleNamespace(
        score=0.9,
        metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"},
        values=SimpleNamespace(value=[1.0, 2.0]),
    )
    match_low = SimpleNamespace(
        score=0.1,
        metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"},
        values=SimpleNamespace(value=[3.0, 4.0]),
    )
    vector._client.query_collection_data.return_value = SimpleNamespace(
        body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high]))
    )

    docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    assert len(docs_by_vector) == 1
    assert docs_by_vector[0].page_content == "high"
    assert docs_by_vector[0].metadata["score"] == 0.9

    docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2)
    assert len(docs_by_text) == 1
    assert docs_by_text[0].page_content == "high"
|
||||
|
||||
|
||||
def test_text_exists_returns_false_when_matches_empty(monkeypatch):
    """text_exists reports False when the query response holds no matches."""
    _install_openapi_stubs(monkeypatch)

    instance = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    instance.config = _config()
    instance._collection_name = "collection_1"

    client = MagicMock()
    empty_response = SimpleNamespace(body=SimpleNamespace(matches=SimpleNamespace(match=[])))
    client.query_collection_data.return_value = empty_response
    instance._client = client

    assert instance.text_exists("missing-id") is False
|
||||
|
||||
|
||||
def test_openapi_delete_success(monkeypatch):
    """delete() drops the backing collection exactly once."""
    _install_openapi_stubs(monkeypatch)

    instance = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    instance.config = _config()
    instance._collection_name = "collection_1"
    client = MagicMock()
    instance._client = client

    instance.delete()

    client.delete_collection.assert_called_once()
|
||||
|
||||
|
||||
def test_openapi_delete_propagates_errors(monkeypatch):
    """delete() does not swallow client-side failures."""
    _install_openapi_stubs(monkeypatch)

    instance = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    instance.config = _config()
    instance._collection_name = "collection_1"
    client = MagicMock()
    client.delete_collection.side_effect = RuntimeError("boom")
    instance._client = client

    with pytest.raises(RuntimeError, match="boom"):
        instance.delete()
|
||||
@ -0,0 +1,427 @@
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
|
||||
AnalyticdbVectorBySql,
|
||||
AnalyticdbVectorBySqlConfig,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _config_values() -> dict:
|
||||
return {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"account": "account",
|
||||
"account_password": "password",
|
||||
"min_connection": 1,
|
||||
"max_connection": 2,
|
||||
"namespace": "dify",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "error_message"),
    [
        ("host", "", "ANALYTICDB_HOST"),
        ("port", 0, "ANALYTICDB_PORT"),
        ("account", "", "ANALYTICDB_ACCOUNT"),
        ("account_password", "", "ANALYTICDB_PASSWORD"),
        ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"),
        ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"),
    ],
)
def test_sql_config_required_fields(field, value, error_message):
    """Each required config field raises a ValueError naming its env var when unset."""
    values = _config_values()
    values[field] = value

    with pytest.raises(ValueError, match=error_message):
        AnalyticdbVectorBySqlConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_sql_config_rejects_min_connection_greater_than_max_connection():
    """Config validation fails when min_connection exceeds max_connection."""
    values = {**_config_values(), "min_connection": 10, "max_connection": 2}

    with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"):
        AnalyticdbVectorBySqlConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_initialize_skips_when_cache_exists(monkeypatch):
    """_initialize skips database setup when the Redis cache flag is present."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    # Cache hit: database already initialized by a previous worker.
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())

    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._initialize_vector_database = MagicMock()

    vector._initialize()

    vector._initialize_vector_database.assert_not_called()
|
||||
|
||||
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch):
    """_initialize performs database setup and caches the flag on a cache miss."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())

    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._initialize_vector_database = MagicMock()

    vector._initialize()

    vector._initialize_vector_database.assert_called_once()
    sql_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch):
    """_create_connection_pool delegates to psycopg2's SimpleConnectionPool."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"

    pool_instance = MagicMock()
    monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool_instance))

    pool = vector._create_connection_pool()

    assert pool is pool_instance
    sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once()
|
||||
|
||||
|
||||
def test_get_cursor_context_manager_handles_connection_lifecycle():
    """_get_cursor yields a cursor and commits/releases the connection on exit."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    cursor = MagicMock()
    connection = MagicMock()
    connection.cursor.return_value = cursor
    pool = MagicMock()
    pool.getconn.return_value = connection
    vector.pool = pool

    with vector._get_cursor() as cur:
        assert cur is cursor

    # Cleanup ordering: close cursor, commit, then return the connection to the pool.
    cursor.close.assert_called_once()
    connection.commit.assert_called_once()
    pool.putconn.assert_called_once_with(connection)
|
||||
|
||||
|
||||
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch):
    """add_texts batches only documents that carry metadata into the INSERT."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"

    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context

    # Pin uuid4 for deterministic row ids; intercept the batch insert.
    monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id")
    monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock())

    docs = [
        Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}),
        SimpleNamespace(page_content="doc 2", metadata=None),
    ]
    vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]])

    execute_args = sql_module.psycopg2.extras.execute_batch.call_args.args
    assert execute_args[0] is cursor
    # Only the metadata-bearing document should produce a row.
    assert len(execute_args[2]) == 1
|
||||
|
||||
|
||||
def test_text_exists_returns_true_and_false_based_on_query_result():
    """text_exists mirrors whether the lookup query yields a row."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    fake_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    vector._get_cursor = _fake_cursor

    for fetched_row, expected in ((("row",), True), (None, False)):
        fake_cursor.fetchone.return_value = fetched_row
        assert vector.text_exists("d1") is expected
|
||||
|
||||
|
||||
def test_delete_by_ids_handles_empty_input_and_missing_table_error():
    """delete_by_ids is a no-op for empty input and tolerates a missing table."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    vector.delete_by_ids([])
    cursor.execute.assert_not_called()

    # UndefinedTable must be swallowed: deleting from a dropped table is fine.
    cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist")
    vector.delete_by_ids(["d1"])
|
||||
|
||||
|
||||
def test_delete_by_metadata_field_handles_missing_table_error():
    """A missing table during metadata-based deletion is swallowed silently."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"

    failing_cursor = MagicMock()
    failing_cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist")

    @contextmanager
    def _fake_cursor():
        yield failing_cursor

    vector._get_cursor = _fake_cursor

    # Must not raise even though the underlying table is gone.
    vector.delete_by_metadata_field("document_id", "doc-1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_top_k", [0, "x", -1])
def test_search_by_vector_validates_top_k(invalid_top_k):
    """search_by_vector rejects non-positive or non-integer top_k values."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k)
|
||||
|
||||
|
||||
def test_search_by_vector_returns_documents_above_threshold():
    """search_by_vector keeps only rows whose score clears the threshold."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()
    # Rows: (id, embedding, score, page_content, metadata); 0.3 is below the 0.5 cut.
    cursor.__iter__.return_value = iter(
        [
            ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}),
            ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}),
        ]
    )

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context

    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])

    assert len(docs) == 1
    assert docs[0].page_content == "content 1"
    assert docs[0].metadata["score"] == 0.8
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_top_k", [0, "x", -1])
def test_search_by_full_text_validates_top_k(invalid_top_k):
    """search_by_full_text rejects non-positive or non-integer top_k values."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("query", top_k=invalid_top_k)
|
||||
|
||||
|
||||
def test_search_by_full_text_returns_documents():
    """search_by_full_text maps result rows into scored Documents."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()
    # Row layout for full-text search: (id, embedding, page_content, metadata, score).
    cursor.__iter__.return_value = iter(
        [
            ("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9),
        ]
    )

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    docs = vector.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"])

    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.9
    assert docs[0].page_content == "content 1"
|
||||
|
||||
|
||||
def test_delete_drops_table():
    """delete() issues a single statement dropping the collection table."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    drop_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield drop_cursor

    vector._get_cursor = _fake_cursor

    vector.delete()

    drop_cursor.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch):
    """__init__ lowercases the collection name, derives the table name, and builds a pool."""
    config = AnalyticdbVectorBySqlConfig(**_config_values())
    created_pool = MagicMock()

    # Bypass real database work during construction.
    monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock())
    monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=created_pool))

    vector = AnalyticdbVectorBySql("My_Collection", config)

    assert vector._collection_name == "my_collection"
    assert vector.table_name == "dify.my_collection"
    assert vector.databaseName == "knowledgebase"
    assert vector.pool is created_pool
|
||||
|
||||
|
||||
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch):
    """Database bootstrap tolerates pre-existing database and search configuration."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"

    # Bootstrap connection: CREATE DATABASE fails because it already exists.
    bootstrap_cursor = MagicMock()
    bootstrap_connection = MagicMock()
    bootstrap_connection.cursor.return_value = bootstrap_cursor
    bootstrap_cursor.execute.side_effect = RuntimeError("database already exists")
    monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection))

    # Worker connection from the pool: only the zh_cn search config creation fails.
    worker_cursor = MagicMock()
    worker_connection = MagicMock()
    worker_cursor.connection = worker_connection

    def _execute(sql, *args, **kwargs):
        if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql:
            raise RuntimeError("already exists")

    worker_cursor.execute.side_effect = _execute
    pooled_connection = MagicMock()
    pooled_connection.cursor.return_value = worker_cursor
    pool = MagicMock()
    pool.getconn.return_value = pooled_connection
    vector._create_connection_pool = MagicMock(return_value=pool)

    vector._initialize_vector_database()

    bootstrap_cursor.close.assert_called_once()
    bootstrap_connection.close.assert_called_once()
    vector._create_connection_pool.assert_called_once()
    # The remaining setup statements must still have been issued.
    assert any(
        "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0]
        for call in worker_cursor.execute.call_args_list
    )
    assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list)
|
||||
|
||||
|
||||
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch):
    """A zhparser extension failure rolls back and surfaces as RuntimeError."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"

    bootstrap_cursor = MagicMock()
    bootstrap_connection = MagicMock()
    bootstrap_connection.cursor.return_value = bootstrap_cursor
    monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection))

    # Every worker statement fails, simulating an unavailable zhparser extension.
    worker_cursor = MagicMock()
    worker_connection = MagicMock()
    worker_cursor.connection = worker_connection
    worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable")

    pooled_connection = MagicMock()
    pooled_connection.cursor.return_value = worker_cursor
    pool = MagicMock()
    pool.getconn.return_value = pooled_connection
    vector._create_connection_pool = MagicMock(return_value=pool)

    with pytest.raises(RuntimeError, match="Failed to create zhparser extension"):
        vector._initialize_vector_database()

    worker_connection.rollback.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch):
    """Collection creation issues table + index DDL and records the cache flag."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._collection_name = "collection"
    vector.table_name = "dify.collection"

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())

    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context

    vector._create_collection_if_not_exists(embedding_dimension=3)

    assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list)
    assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list)
    sql_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch):
    """Unexpected DDL errors (not "already exists") propagate to the caller."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._collection_name = "collection"
    vector.table_name = "dify.collection"

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())

    cursor = MagicMock()
    cursor.execute.side_effect = RuntimeError("permission denied")

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context

    with pytest.raises(RuntimeError, match="permission denied"):
        vector._create_collection_if_not_exists(embedding_dimension=3)
|
||||
|
||||
|
||||
def test_delete_methods_raise_when_error_is_not_missing_table():
    """Delete helpers only swallow UndefinedTable; other errors must propagate."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context

    cursor.execute.side_effect = RuntimeError("unexpected delete failure")
    with pytest.raises(RuntimeError, match="unexpected delete failure"):
        vector.delete_by_ids(["doc-1"])

    cursor.execute.side_effect = RuntimeError("unexpected metadata failure")
    with pytest.raises(RuntimeError, match="unexpected metadata failure"):
        vector.delete_by_metadata_field("document_id", "doc-1")
|
||||
@ -0,0 +1,542 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_pymochow_modules():
    """Build an in-memory stand-in for the pymochow package tree.

    Returns a mapping of dotted module name -> fake module, suitable for
    injection into sys.modules so baidu_vector can be imported without the
    real pymochow dependency installed.
    """
    pymochow = types.ModuleType("pymochow")
    pymochow.__path__ = []  # mark as a package so submodule imports resolve
    pymochow_auth = types.ModuleType("pymochow.auth")
    pymochow_auth.__path__ = []
    pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials")
    pymochow_configuration = types.ModuleType("pymochow.configuration")
    pymochow_exception = types.ModuleType("pymochow.exception")
    pymochow_model = types.ModuleType("pymochow.model")
    pymochow_model.__path__ = []
    pymochow_model_database = types.ModuleType("pymochow.model.database")
    pymochow_model_enum = types.ModuleType("pymochow.model.enum")
    pymochow_model_schema = types.ModuleType("pymochow.model.schema")
    pymochow_model_table = types.ModuleType("pymochow.model.table")

    class _SimpleObject:
        # Generic record type: stores positional args and exposes kwargs as attributes.
        def __init__(self, *args, **kwargs):
            self.args = args
            for key, value in kwargs.items():
                setattr(self, key, value)

    class ServerError(Exception):
        # Mirrors pymochow's ServerError: carries a numeric error code.
        def __init__(self, code):
            super().__init__(f"server error {code}")
            self.code = code

    class ServerErrCode:
        TABLE_NOT_EXIST = 1001
        DB_ALREADY_EXIST = 1002

    class IndexType:
        __members__ = {"HNSW": "HNSW"}

    class MetricType:
        __members__ = {"IP": "IP"}

    class IndexState:
        NORMAL = "NORMAL"

    class TableState:
        NORMAL = "NORMAL"

    class InvertedIndexAnalyzer:
        DEFAULT_ANALYZER = "DEFAULT_ANALYZER"

    class InvertedIndexParseMode:
        COARSE_MODE = "COARSE_MODE"

    class InvertedIndexFieldAttribute:
        ANALYZED = "ANALYZED"

    class FieldType:
        STRING = "STRING"
        TEXT = "TEXT"
        JSON = "JSON"
        FLOAT_VECTOR = "FLOAT_VECTOR"

    pymochow.MochowClient = _SimpleObject
    pymochow_credentials.BceCredentials = _SimpleObject
    pymochow_configuration.Configuration = _SimpleObject
    pymochow_exception.ServerError = ServerError
    pymochow_model_database.Database = _SimpleObject

    pymochow_model_enum.FieldType = FieldType
    pymochow_model_enum.IndexState = IndexState
    pymochow_model_enum.IndexType = IndexType
    pymochow_model_enum.MetricType = MetricType
    pymochow_model_enum.ServerErrCode = ServerErrCode
    pymochow_model_enum.TableState = TableState

    # Schema/table classes that baidu_vector instantiates but whose behavior is irrelevant here.
    for cls_name in [
        "AutoBuildRowCountIncrement",
        "Field",
        "FilteringIndex",
        "HNSWParams",
        "InvertedIndex",
        "InvertedIndexParams",
        "Schema",
        "VectorIndex",
    ]:
        setattr(pymochow_model_schema, cls_name, _SimpleObject)
    pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer
    pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute
    pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode

    for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]:
        setattr(pymochow_model_table, cls_name, _SimpleObject)

    # Wire submodules onto their parents so attribute-style access also works.
    pymochow.auth = pymochow_auth
    pymochow.model = pymochow_model
    pymochow_auth.bce_credentials = pymochow_credentials
    pymochow_model.database = pymochow_model_database
    pymochow_model.enum = pymochow_model_enum
    pymochow_model.schema = pymochow_model_schema
    pymochow_model.table = pymochow_model_table

    modules = {
        "pymochow": pymochow,
        "pymochow.auth": pymochow_auth,
        "pymochow.auth.bce_credentials": pymochow_credentials,
        "pymochow.configuration": pymochow_configuration,
        "pymochow.exception": pymochow_exception,
        "pymochow.model": pymochow_model,
        "pymochow.model.database": pymochow_model_database,
        "pymochow.model.enum": pymochow_model_enum,
        "pymochow.model.schema": pymochow_model_schema,
        "pymochow.model.table": pymochow_model_table,
    }
    return modules
|
||||
|
||||
|
||||
@pytest.fixture
def baidu_module(monkeypatch):
    """Import baidu_vector against stubbed pymochow packages and return it."""
    for module_name, fake_module in _build_fake_pymochow_modules().items():
        monkeypatch.setitem(sys.modules, module_name, fake_module)

    import core.rag.datasource.vdb.baidu.baidu_vector as baidu_vector_module

    # Reload so the module re-binds to the stubbed pymochow symbols.
    return importlib.reload(baidu_vector_module)
|
||||
|
||||
|
||||
def test_baidu_config_validation(baidu_module):
    """BaiduConfig accepts complete values and rejects each missing required field."""
    values = {
        "endpoint": "https://example.com",
        "account": "account",
        "api_key": "key",
        "database": "database",
    }
    config = baidu_module.BaiduConfig.model_validate(values)
    assert config.endpoint == "https://example.com"

    # Each required field, when blanked, should raise naming its env variable.
    for key, error_message in [
        ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"),
        ("account", "BAIDU_VECTOR_DB_ACCOUNT"),
        ("api_key", "BAIDU_VECTOR_DB_API_KEY"),
        ("database", "BAIDU_VECTOR_DB_DATABASE"),
    ]:
        invalid = dict(values)
        invalid[key] = ""
        with pytest.raises(ValueError, match=error_message):
            baidu_module.BaiduConfig.model_validate(invalid)
|
||||
|
||||
|
||||
def test_get_search_result_handles_metadata_and_threshold(baidu_module):
    """_get_search_res parses string/dict/invalid metadata and applies the threshold."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    # Rows cover: JSON-string metadata (kept), below-threshold (dropped),
    # and non-parseable metadata (kept, metadata handled defensively).
    response = SimpleNamespace(
        rows=[
            {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9},
            {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4},
            {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95},
        ]
    )

    docs = vector._get_search_res(response, score_threshold=0.8)

    assert len(docs) == 2
    assert docs[0].page_content == "doc1"
    assert docs[0].metadata["score"] == 0.9
    assert docs[1].page_content == "doc3"
|
||||
|
||||
|
||||
def test_delete_by_ids_and_delete_by_metadata_field(baidu_module):
    """Deletion helpers skip empty input and escape quotes in metadata filters."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    vector._collection_name = "collection_1"

    # Empty id list: no delete issued.
    vector.delete_by_ids([])
    table.delete.assert_not_called()

    vector.delete_by_ids(["id1", "id2"])
    table.delete.assert_called_once()

    # Double quotes inside the value must be backslash-escaped in the filter.
    table.delete.reset_mock()
    vector.delete_by_metadata_field("source", 'abc"def')
    delete_filter = table.delete.call_args.kwargs["filter"]
    assert delete_filter == 'metadata["source"] = "abc\\"def"'
|
||||
|
||||
|
||||
def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module):
    """delete() ignores TABLE_NOT_EXIST but propagates other server errors."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    database = MagicMock()
    vector._db = database
    vector._collection_name = "collection_1"

    # A missing table is tolerated silently.
    database.drop_table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST)
    vector.delete()

    # Any other server error code must bubble up.
    database.drop_table.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector.delete()
|
||||
|
||||
|
||||
def test_init_database_uses_existing_or_creates_when_missing(baidu_module):
    """_init_database reuses an existing database, creates when absent, and tolerates races."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._client = MagicMock()
    vector._client_config = SimpleNamespace(database="my_db")

    # Database already listed: use it directly.
    vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")]
    vector._client.database.return_value = "existing_db"
    assert vector._init_database() == "existing_db"

    # Not listed: create then open.
    vector._client.list_databases.return_value = []
    vector._client.database.return_value = "created_db"
    vector._client.create_database.side_effect = None
    assert vector._init_database() == "created_db"

    # Concurrent creation race: DB_ALREADY_EXIST is tolerated.
    vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST)
    assert vector._init_database() == "created_db"
|
||||
|
||||
|
||||
def test_table_existed_checks_table_access(baidu_module):
    """_table_existed maps table lookups to booleans and re-raises unknown errors."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    database = MagicMock()
    vector._db = database

    # Successful lookup means the table exists.
    database.table.return_value = MagicMock()
    assert vector._table_existed() is True

    # TABLE_NOT_EXIST translates to False.
    database.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST)
    assert vector._table_existed() is False

    # Any other server error code must propagate.
    database.table.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector._table_existed()
|
||||
|
||||
|
||||
def test_search_methods_delegate_to_database_table(baidu_module):
    """Vector and full-text search both delegate to the table then to _get_search_res."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    vector._get_search_res = MagicMock(return_value=[Document(page_content="doc", metadata={"doc_id": "1"})])

    table = MagicMock()
    vector._db.table.return_value = table
    table.search.return_value = "vector_result"
    table.bm25_search.return_value = "bm25_result"

    result1 = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2)
    result2 = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2)

    assert result1 == vector._get_search_res.return_value
    assert result2 == vector._get_search_res.return_value
    assert vector._get_search_res.call_count == 2
|
||||
|
||||
|
||||
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch):
    """With no index struct on the dataset, the factory generates a collection name
    (lowercased) and persists a new index struct on the dataset."""
    factory = baidu_module.BaiduVectorFactory()
    dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full Baidu VDB configuration the factory reads from dify_config.
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300)

    with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())

    assert result == "vector"
    # Generated names are lowercased before being handed to BaiduVector.
    assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection"
    assert dataset.index_struct is not None
|
||||
|
||||
|
||||
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch):
    """The constructor wires client/database; get_type/to_index_struct report identity;
    create() provisions the table before delegating to add_texts."""
    init_client = MagicMock(return_value="client")
    init_database = MagicMock(return_value="database")
    monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client)
    monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database)

    config = baidu_module.BaiduConfig(
        endpoint="https://example.com",
        account="account",
        api_key="key",
        database="db",
    )
    vector = baidu_module.BaiduVector(collection_name="my_collection", config=config)

    assert vector.get_type() == baidu_module.VectorType.BAIDU
    assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection"
    # The patched init helpers supply the client and database handles.
    assert vector._client == "client"
    assert vector._db == "database"

    vector._create_table = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="p1", metadata={"doc_id": "d1"})]
    vector.create(docs, [[0.1, 0.2]])
    # The table is created with the embedding dimension (2), then rows are inserted.
    vector._create_table.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_batches_rows(baidu_module):
    """Two documents fit in one batch, so upsert is invoked exactly once with both rows."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    mock_table = MagicMock()
    vector._db.table.return_value = mock_table

    documents = [
        Document(page_content=f"doc-{n}", metadata={"doc_id": f"id-{n}", "document_id": f"doc-{n}"})
        for n in (1, 2)
    ]
    vector.add_texts(documents, [[0.1, 0.2], [0.3, 0.4]])

    mock_table.upsert.assert_called_once()
    assert len(mock_table.upsert.call_args.kwargs["rows"]) == 2
|
||||
|
||||
|
||||
def test_add_texts_batches_more_than_batch_size(baidu_module):
    """1001 documents are written in two upsert batches: 1000 rows, then the 1-row remainder."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    mock_table = MagicMock()
    vector._db.table.return_value = mock_table

    total = 1001
    documents = [
        Document(page_content=f"doc-{n}", metadata={"doc_id": f"id-{n}", "document_id": f"doc-{n}"})
        for n in range(total)
    ]
    embeddings = [[0.1, 0.2] for _ in range(total)]

    vector.add_texts(documents, embeddings)

    # Batch size is 1000, so the rows split into one full batch plus the remainder.
    batch_sizes = [len(call.kwargs["rows"]) for call in mock_table.upsert.call_args_list]
    assert batch_sizes == [1000, 1]
|
||||
|
||||
|
||||
def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module):
    """text_exists is True only for a query response with code == 0."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table

    table.query.return_value = SimpleNamespace(code=0)
    assert vector.text_exists("id-1") is True

    # A non-zero code means the lookup did not succeed.
    table.query.return_value = SimpleNamespace(code=1)
    assert vector.text_exists("id-1") is False

    # A missing response object is also treated as "not found".
    table.query.return_value = None
    assert vector.text_exists("id-1") is False
|
||||
|
||||
|
||||
def test_get_search_result_handles_invalid_metadata_json(baidu_module):
    """Rows whose metadata JSON cannot be parsed still yield a document with the score attached."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    response = SimpleNamespace(rows=[{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}])

    docs = vector._get_search_res(response, score_threshold=0.1)

    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.7
    # The broken JSON payload contributes no metadata fields of its own.
    assert "document_id" not in docs[0].metadata
|
||||
|
||||
|
||||
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch):
    """_init_client chains BceCredentials -> Configuration -> MochowClient with config values."""
    credentials = MagicMock(return_value="credentials")
    configuration = MagicMock(return_value="configuration")
    client_cls = MagicMock(return_value="client")
    monkeypatch.setattr(baidu_module, "BceCredentials", credentials)
    monkeypatch.setattr(baidu_module, "Configuration", configuration)
    monkeypatch.setattr(baidu_module, "MochowClient", client_cls)

    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint")

    client = vector._init_client(config)

    # Each constructor receives the output of the previous stage.
    assert client == "client"
    credentials.assert_called_once_with("account", "key")
    configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint")
    client_cls.assert_called_once_with("configuration")
|
||||
|
||||
|
||||
def test_init_database_raises_for_unknown_create_database_error(baidu_module):
    """An unrecognised ServerError from create_database must bubble up, not be swallowed."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    client = MagicMock()
    client.list_databases.return_value = []
    client.create_database.side_effect = baidu_module.ServerError(9999)
    vector._client = client
    vector._client_config = SimpleNamespace(database="my_db")

    with pytest.raises(baidu_module.ServerError):
        vector._init_database()
|
||||
|
||||
|
||||
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch):
    """_create_table is skipped for cached or pre-existing tables; otherwise it creates the
    table, caches the fact in redis, rebuilds the index, and waits for readiness."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._client_config = SimpleNamespace(
        index_type="HNSW",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    vector._db = MagicMock()
    table = MagicMock()
    table.state = baidu_module.TableState.NORMAL
    vector._db.describe_table.return_value = table
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()

    # Stub out the redis lock/cache and sleeping so each phase runs synchronously.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None)
    monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock())

    # Cached table skips all work.
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_table(3)
    vector._db.create_table.assert_not_called()

    # Existing table also skips creation.
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    vector._table_existed.return_value = True
    vector._create_table(3)
    vector._db.create_table.assert_not_called()

    # Create table when cache is empty and table does not exist.
    vector._table_existed.return_value = False
    vector._create_table(3)
    vector._db.create_table.assert_called_once()
    baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600)
    table.rebuild_index.assert_called_once_with(vector.vector_index)
    vector._wait_for_index_ready.assert_called_once_with(table, 3600)
|
||||
|
||||
|
||||
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch):
    """Unsupported index_type or metric_type values abort table creation with ValueError."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()
    vector._client_config = SimpleNamespace(
        index_type="INVALID",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )

    # Redis lock/cache are mocked so the validation code is reached directly.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))

    with pytest.raises(ValueError, match="unsupported index_type"):
        vector._create_table(3)

    # A valid index type combined with a bad metric type fails on the metric check instead.
    vector._client_config.index_type = "HNSW"
    vector._client_config.metric_type = "INVALID"
    with pytest.raises(ValueError, match="unsupported metric_type"):
        vector._create_table(3)
|
||||
|
||||
|
||||
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch):
    """If the table never reaches NORMAL state before the deadline, _create_table raises TimeoutError."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._client_config = SimpleNamespace(
        index_type="HNSW",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    vector._db = MagicMock()
    # describe_table never reports NORMAL, so the wait loop can only time out.
    vector._db.describe_table.return_value = SimpleNamespace(state="CREATING")
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None)
    # time.time jumps past the deadline on its second read (0 -> 301).
    monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301]))

    with pytest.raises(TimeoutError, match="Table creation timeout"):
        vector._create_table(3)
|
||||
|
||||
|
||||
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch):
    """When the dataset already carries an index struct, its class_prefix (lowercased) is reused."""
    factory = baidu_module.BaiduVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    # Provide the full Baidu VDB configuration the factory reads from dify_config.
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300)

    with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())

    assert result == "vector"
    # The pre-existing class prefix is reused in lowercase form.
    assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
|
||||
@ -0,0 +1,199 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from collections import UserDict
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_chroma_modules():
|
||||
chroma = types.ModuleType("chromadb")
|
||||
chroma.DEFAULT_TENANT = "default_tenant"
|
||||
chroma.DEFAULT_DATABASE = "default_database"
|
||||
|
||||
class Settings:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
class QueryResult(UserDict):
|
||||
pass
|
||||
|
||||
class _Collection:
|
||||
def __init__(self):
|
||||
self.upsert = MagicMock()
|
||||
self.delete = MagicMock()
|
||||
self.query = MagicMock()
|
||||
self.get = MagicMock(return_value={})
|
||||
|
||||
class _Client:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.collection = _Collection()
|
||||
self.get_or_create_collection = MagicMock(return_value=self.collection)
|
||||
self.delete_collection = MagicMock()
|
||||
|
||||
chroma.Settings = Settings
|
||||
chroma.QueryResult = QueryResult
|
||||
chroma.HttpClient = _Client
|
||||
return chroma
|
||||
|
||||
|
||||
@pytest.fixture
def chroma_module(monkeypatch):
    """Reload the chroma vector module against the fake in-memory chromadb package."""
    fake_chroma = _build_fake_chroma_modules()
    monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
    import core.rag.datasource.vdb.chroma.chroma_vector as module

    # Reload so the module-level chromadb references bind to the fake.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def test_chroma_config_to_params_builds_expected_payload(chroma_module):
    """to_chroma_params exposes the connection fields plus auth values inside a Settings object."""
    config = chroma_module.ChromaConfig(
        host="localhost",
        port=8000,
        tenant="tenant-1",
        database="db-1",
        auth_provider="provider",
        auth_credentials="credentials",
    )

    params = config.to_chroma_params()

    assert params["host"] == "localhost"
    assert params["port"] == 8000
    assert params["tenant"] == "tenant-1"
    assert params["database"] == "db-1"
    # SSL is off by default for this config.
    assert params["ssl"] is False
    # Auth values are forwarded into the chroma Settings payload.
    assert params["settings"].chroma_client_auth_provider == "provider"
    assert params["settings"].chroma_client_auth_credentials == "credentials"
|
||||
|
||||
|
||||
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch):
    """create_collection takes the redis lock, creates the collection, and caches the result."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(chroma_module.redis_client, "lock", MagicMock(return_value=lock))
    # Empty cache forces the creation path.
    monkeypatch.setattr(chroma_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(chroma_module.redis_client, "set", MagicMock())

    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    vector.create_collection("collection_1")

    vector._client.get_or_create_collection.assert_called_once_with("collection_1")
    chroma_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_with_empty_texts_is_noop(chroma_module):
    """create() with no documents must not touch the collection at all."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    vector = chroma_module.ChromaVector(collection_name="collection_1", config=config)

    vector.create([], [])

    vector._client.get_or_create_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_create_with_texts_creates_collection_and_upserts(chroma_module):
    """create() with documents ensures the collection exists and upserts exactly once."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    vector = chroma_module.ChromaVector(collection_name="collection_1", config=config)

    document = Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})
    vector.create([document], [[0.1, 0.2]])

    vector._client.get_or_create_collection.assert_called()
    vector._client.collection.upsert.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_methods_and_text_exists(chroma_module):
    """Delete helpers forward to the collection; text_exists checks the returned ids."""
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )

    # An empty id list is a no-op.
    vector.delete_by_ids([])
    vector._client.collection.delete.assert_not_called()

    vector.delete_by_ids(["id-1"])
    vector._client.collection.delete.assert_called_with(ids=["id-1"])

    # Metadata deletes translate into a chroma "$eq" where-clause.
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}})

    # text_exists is True only when the id comes back from collection.get.
    vector._client.collection.get.return_value = {"ids": ["id-1"]}
    assert vector.text_exists("id-1") is True
    vector._client.collection.get.return_value = {}
    assert vector.text_exists("id-2") is False

    # delete() drops the entire collection.
    vector.delete()
    vector._client.delete_collection.assert_called_once_with("collection_1")
|
||||
|
||||
|
||||
def test_search_by_vector_handles_empty_results(chroma_module):
    """An empty chroma query result maps to an empty document list."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    vector = chroma_module.ChromaVector(collection_name="collection_1", config=config)

    empty_result = {"ids": [], "documents": [], "metadatas": [], "distances": []}
    vector._client.collection.query.return_value = empty_result

    assert vector.search_by_vector([0.1, 0.2], top_k=2) == []
|
||||
|
||||
|
||||
def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module):
    """Distances convert to scores as 1 - distance; rows below the threshold are dropped."""
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    vector._client.collection.query.return_value = {
        "ids": [["id-1", "id-2"]],
        "documents": [["doc high", "doc low"]],
        "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]],
        "distances": [[0.1, 0.8]],
    }

    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])

    # distance 0.1 -> score 0.9 passes the 0.5 threshold; distance 0.8 -> score 0.2 is filtered out.
    assert len(docs) == 1
    assert docs[0].page_content == "doc high"
    assert docs[0].metadata["score"] == 0.9
|
||||
|
||||
|
||||
def test_search_by_full_text_returns_empty_list(chroma_module):
    """Full-text search on the chroma adapter always yields an empty result list."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    instance = chroma_module.ChromaVector(collection_name="collection_1", config=config)

    assert instance.search_by_full_text("query") == []
|
||||
|
||||
|
||||
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch):
    """The factory reuses an existing class_prefix (lowercased) or generates a collection
    name and persists a new index struct on the dataset."""
    factory = chroma_module.ChromaVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(chroma_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the chroma connection configuration the factory reads from dify_config.
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost")
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None)

    with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # Both reused and generated names are lowercased.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,927 @@
|
||||
import importlib
|
||||
import queue
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_clickzetta_module():
|
||||
clickzetta = types.ModuleType("clickzetta")
|
||||
|
||||
class _FakeCursor:
|
||||
def __init__(self):
|
||||
self.execute = MagicMock()
|
||||
self.executemany = MagicMock()
|
||||
self.fetchall = MagicMock(return_value=[])
|
||||
self.fetchone = MagicMock(return_value=(0,))
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class _FakeConnection:
|
||||
def __init__(self):
|
||||
self.cursor_obj = _FakeCursor()
|
||||
|
||||
def cursor(self):
|
||||
return self.cursor_obj
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
def connect(**_kwargs):
|
||||
return _FakeConnection()
|
||||
|
||||
clickzetta.connect = connect
|
||||
return clickzetta
|
||||
|
||||
|
||||
@pytest.fixture
def clickzetta_module(monkeypatch):
    """Reload the clickzetta vector module against the fake in-memory clickzetta package."""
    monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
    import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module

    # Reload so the module-level clickzetta references bind to the fake.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module):
|
||||
return module.ClickzettaConfig(
|
||||
username="username",
|
||||
password="password",
|
||||
instance="instance",
|
||||
service="service",
|
||||
workspace="workspace",
|
||||
vcluster="cluster",
|
||||
schema_name="dify",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "error_message"),
    [
        ("username", "CLICKZETTA_USERNAME"),
        ("password", "CLICKZETTA_PASSWORD"),
        ("instance", "CLICKZETTA_INSTANCE"),
        ("service", "CLICKZETTA_SERVICE"),
        ("workspace", "CLICKZETTA_WORKSPACE"),
        ("vcluster", "CLICKZETTA_VCLUSTER"),
        ("schema_name", "CLICKZETTA_SCHEMA"),
    ],
)
def test_clickzetta_config_validation(clickzetta_module, field, error_message):
    """Each required config field, when blank, raises a ValueError naming its setting."""
    values = _config(clickzetta_module).model_dump()
    values[field] = ""
    with pytest.raises(ValueError, match=error_message):
        clickzetta_module.ClickzettaConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module):
    """_parse_metadata accepts plain JSON, unwraps double-encoded JSON, and falls back on garbage."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)

    # Plain JSON: doc_id is taken from the row id, document_id from the payload.
    parsed = vector._parse_metadata('{"document_id":"doc-1"}', "row-1")
    assert parsed["doc_id"] == "row-1"
    assert parsed["document_id"] == "doc-1"

    # JSON that was serialized twice is unwrapped transparently.
    parsed_double = vector._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2")
    assert parsed_double["doc_id"] == "row-2"
    assert parsed_double["document_id"] == "doc-2"

    # Unparsable metadata falls back to the row id for both identifiers.
    parsed_fallback = vector._parse_metadata("not-json", "row-3")
    assert parsed_fallback["doc_id"] == "row-3"
    assert parsed_fallback["document_id"] == "row-3"
|
||||
|
||||
|
||||
def test_safe_doc_id_and_vector_format_helpers(clickzetta_module):
    """_format_vector_simple comma-joins floats; _safe_doc_id strips unsafe chars and caps length."""
    instance = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)

    assert instance._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3"
    # Allowed characters pass through untouched.
    assert instance._safe_doc_id("abc-123_DEF") == "abc-123_DEF"
    # Whitespace and punctuation are removed.
    assert instance._safe_doc_id("ab c;\n") == "abc"
    # Overlong ids are truncated to 255 characters.
    assert len(instance._safe_doc_id("a" * 300)) == 255
|
||||
|
||||
|
||||
def test_table_exists_returns_false_for_not_found_and_other_exceptions(clickzetta_module):
    """_table_exists reports False both for "not found" errors and for any other failure."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"

    @contextmanager
    def _ctx_not_found():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        # Clickzetta reports a missing table via a CZLH-42000 error message.
        cursor.execute.side_effect = RuntimeError("CZLH-42000 table or view not found")
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx_not_found
    assert vector._table_exists() is False

    @contextmanager
    def _ctx_other_error():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("permission denied")
        connection.cursor.return_value = cursor
        yield connection

    # Unrelated errors also surface as "table does not exist" rather than raising.
    vector.get_connection_context = _ctx_other_error
    assert vector._table_exists() is False
|
||||
|
||||
|
||||
def test_text_exists_handles_missing_table_and_existing_rows(clickzetta_module):
    """text_exists is False when the table is missing and True when a matching row is found."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"

    # No table means the id cannot possibly exist.
    vector._table_exists = MagicMock(return_value=False)
    assert vector.text_exists("doc-1") is False

    vector._table_exists = MagicMock(return_value=True)

    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        # fetchone reports one matching row for the queried id.
        cursor.fetchone.return_value = (1,)
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    assert vector.text_exists("doc-1") is True
|
||||
|
||||
|
||||
def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module):
    """Delete helpers never reach the write path for empty id lists or a missing table."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()

    # No ids -> nothing to delete.
    vector.delete_by_ids([])
    vector._execute_write.assert_not_called()

    # Missing table -> both delete variants are no-ops.
    vector._table_exists = MagicMock(return_value=False)
    vector.delete_by_ids(["doc-1"])
    vector._execute_write.assert_not_called()

    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._execute_write.assert_not_called()
|
||||
|
||||
|
||||
def test_search_short_circuit_behaviors(clickzetta_module):
    """Searches yield [] when the table is missing or the inverted index is disabled."""
    instance = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    instance._config = _config(clickzetta_module)
    instance._table_name = "table_1"

    instance._table_exists = MagicMock(return_value=False)
    assert instance.search_by_vector([0.1, 0.2], top_k=2) == []

    instance._config.enable_inverted_index = False
    assert instance.search_by_full_text("query", top_k=2) == []
|
||||
|
||||
|
||||
def test_search_by_like_returns_documents_with_default_score(clickzetta_module):
    """_search_by_like wraps matches as documents carrying a fixed 0.5 score."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)
    vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"})

    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        # One matching row: (id, page content, raw metadata JSON).
        cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')]
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _ctx
    docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"])

    assert len(docs) == 1
    assert docs[0].page_content == "content"
    # LIKE matches carry a neutral 0.5 score rather than a computed similarity.
    assert docs[0].metadata["score"] == 0.5
|
||||
|
||||
|
||||
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
    """The factory builds a ClickzettaVector with a lowercased generated collection name."""
    factory = clickzetta_module.ClickzettaVectorFactory()
    dataset = SimpleNamespace(id="dataset-1")

    monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION")
    # Provide the full Clickzetta configuration the factory reads from dify_config.
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10)
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True)
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance")

    with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())

    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "collection"
|
||||
|
||||
|
||||
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
    """get_instance returns a process-wide singleton and config keys embed identity fields."""
    clickzetta_module.ClickzettaConnectionPool._instance = None
    # Prevent the real background cleanup thread from starting during the test.
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())

    first_instance = clickzetta_module.ClickzettaConnectionPool.get_instance()
    second_instance = clickzetta_module.ClickzettaConnectionPool.get_instance()
    config_key = first_instance._get_config_key(_config(clickzetta_module))

    assert first_instance is second_instance
    assert "username:instance:service:workspace:cluster:dify" in config_key
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch):
    """_create_connection retries after one failure and configures the eventual connection."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    config = _config(clickzetta_module)
    good_connection = MagicMock()

    # First connect attempt fails, the second succeeds; sleep is neutralized.
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    monkeypatch.setattr(
        clickzetta_module.clickzetta, "connect", MagicMock(side_effect=[RuntimeError("boom"), good_connection])
    )
    pool._configure_connection = MagicMock()

    result = pool._create_connection(config)

    assert result is good_connection
    assert clickzetta_module.clickzetta.connect.call_count == 2
    pool._configure_connection.assert_called_once_with(good_connection)
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch):
    """_create_connection propagates the final error once every retry has failed."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()

    # Every connect attempt fails; sleep between retries is neutralized.
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    monkeypatch.setattr(clickzetta_module.clickzetta, "connect", MagicMock(side_effect=RuntimeError("boom")))

    with pytest.raises(RuntimeError, match="boom"):
        pool._create_connection(_config(clickzetta_module))
def test_connection_pool_configure_and_validate_connection(clickzetta_module):
    """_configure_connection issues setup statements; _is_connection_valid probes the cursor.

    Uses ``pytest.MonkeyPatch.context()`` instead of a manually constructed
    ``MonkeyPatch`` with a trailing ``undo()``: the manual form skipped the
    undo whenever an assertion failed, leaking the patch into later tests.
    """
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
        pool = clickzetta_module.ClickzettaConnectionPool()

        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection = MagicMock()
        connection.cursor.return_value = cursor

        # Configuration should run at least two session-setup statements.
        pool._configure_connection(connection)
        assert cursor.execute.call_count >= 2
        assert pool._is_connection_valid(connection) is True

        # A connection whose cursor() raises is reported invalid, not raised.
        bad_connection = MagicMock()
        bad_connection.cursor.side_effect = RuntimeError("bad connection")
        assert pool._is_connection_valid(bad_connection) is False
def test_connection_pool_configure_connection_swallows_errors(clickzetta_module):
    """_configure_connection must not raise when the connection cannot be configured.

    Uses ``pytest.MonkeyPatch.context()`` instead of a manual ``MonkeyPatch``
    plus trailing ``undo()``, so the patch is always reverted even when the
    call under test raises unexpectedly.
    """
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
        pool = clickzetta_module.ClickzettaConnectionPool()
        connection = MagicMock()
        connection.cursor.side_effect = RuntimeError("cannot configure")

        # The configuration error must be swallowed rather than propagated.
        pool._configure_connection(connection)
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch):
    """Exercise connection checkout, reuse, expiry, return, cleanup and shutdown."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    config = _config(clickzetta_module)
    key = pool._get_config_key(config)

    # Empty pool: a brand-new connection is created.
    fresh_connection = MagicMock()
    pool._create_connection = MagicMock(return_value=fresh_connection)
    assert pool.get_connection(config) is fresh_connection

    # A valid pooled connection is handed back instead of creating another.
    pooled_connection = MagicMock()
    pool._pools[key] = [(pooled_connection, clickzetta_module.time.time())]
    pool._is_connection_valid = MagicMock(return_value=True)
    assert pool.get_connection(config) is pooled_connection

    # An invalid/expired pooled connection is closed before a replacement is made.
    stale_connection = MagicMock()
    pool._pools[key] = [(stale_connection, 0.0)]
    pool._is_connection_valid = MagicMock(return_value=False)
    monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0))
    pool.get_connection(config)
    stale_connection.close.assert_called_once()

    # Returning a valid connection puts it back into the pool.
    returned_connection = MagicMock()
    pool._is_connection_valid = MagicMock(return_value=True)
    pool.return_connection(config, returned_connection)
    assert len(pool._pools[key]) == 1

    # Cleanup drops only the entries past the timeout.
    pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)]
    pool._connection_timeout = 10
    pool._cleanup_expired_connections()
    assert len(pool._pools[key]) == 1

    # Returning to an unknown pool key simply closes the connection.
    stray_connection = MagicMock()
    pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), stray_connection)
    stray_connection.close.assert_called_once()

    pool.shutdown()
    assert pool._shutdown is True
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch):
    """The cleanup thread is started once and its loop invokes the cleanup routine."""
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._shutdown = False
    # The first cleanup call flips _shutdown so the worker loop terminates.
    pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True))

    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)

    class _InlineThread:
        """Thread substitute that runs its target synchronously inside start()."""

        def __init__(self, target, daemon):
            self._target = target
            self.daemon = daemon
            self.started = False

        def start(self):
            self.started = True
            self._target()

    monkeypatch.setattr(clickzetta_module.threading, "Thread", _InlineThread)
    pool._start_cleanup_thread()

    assert pool._cleanup_thread.started is True
    pool._cleanup_expired_connections.assert_called_once()
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch):
    """Construction normalizes the table name and the helpers delegate to the pool."""
    pool = MagicMock()
    pool.get_connection.return_value = "conn"
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool))
    monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock())

    vector = clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module))
    # "My-Collection" is normalized to a safe lowercase table name.
    assert vector._table_name == "my_collection"

    assert vector._get_connection() == "conn"
    vector._return_connection("conn")
    pool.return_connection.assert_called_with(vector._config, "conn")

    # The context manager checks out and returns a connection around the body.
    with vector.get_connection_context() as connection:
        assert connection == "conn"
    assert pool.return_connection.call_count >= 2

    assert vector.get_type() == "clickzetta"
    assert vector._ensure_connection() == "conn"
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch):
    """Cover queue bootstrap, the worker loop, and _execute_write success/failure paths."""

    class _RecordingThread:
        """Thread substitute that only counts how often start() runs."""

        def __init__(self, target, daemon):
            self.target = target
            self.daemon = daemon
            self.started = 0

        def start(self):
            self.started += 1

    monkeypatch.setattr(clickzetta_module.threading, "Thread", _RecordingThread)
    clickzetta_module.ClickzettaVector._write_queue = None
    clickzetta_module.ClickzettaVector._write_thread = None
    clickzetta_module.ClickzettaVector._shutdown = False
    # A second init call must be a no-op: the worker thread starts exactly once.
    clickzetta_module.ClickzettaVector._init_write_queue()
    clickzetta_module.ClickzettaVector._init_write_queue()
    assert clickzetta_module.ClickzettaVector._write_thread.started == 1

    ok_results = queue.Queue()
    failed_results = queue.Queue()
    clickzetta_module.ClickzettaVector._write_queue = queue.Queue()
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, ok_results))
    clickzetta_module.ClickzettaVector._write_queue.put(
        (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, failed_results)
    )
    clickzetta_module.ClickzettaVector._write_queue.put(None)  # sentinel stops the worker
    clickzetta_module.ClickzettaVector._write_worker()

    assert ok_results.get() == (True, 2)
    failure = failed_results.get()
    assert failure[0] is False
    assert isinstance(failure[1], RuntimeError)

    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    clickzetta_module.ClickzettaVector._write_queue = None
    with pytest.raises(RuntimeError, match="Write queue not initialized"):
        vector._execute_write(lambda: None)

    class _ImmediateSuccessQueue:
        """Executes the task inline and reports its result as a success."""

        def put(self, task):
            func, args, kwargs, result_q = task
            result_q.put((True, func(*args, **kwargs)))

    clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue()
    assert vector._execute_write(lambda x: x * 2, 3) == 6

    class _ImmediateFailQueue:
        """Reports a failure without running the task."""

        def put(self, task):
            _, _, _, result_q = task
            result_q.put((False, ValueError("write failed")))

    clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue()
    with pytest.raises(ValueError, match="write failed"):
        vector._execute_write(lambda: None)
def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module):
    """_table_exists reports True via the cursor and create() schedules setup + inserts."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _fake_context
    assert vector._table_exists() is True

    vector._execute_write = MagicMock()
    vector.add_texts = MagicMock()
    documents = [Document(page_content="content", metadata={"doc_id": "d1"})]
    vector.create(documents, [[0.1, 0.2]])
    vector._execute_write.assert_called_once()
    vector.add_texts.assert_called_once_with(documents, [[0.1, 0.2]])
def test_create_table_and_indexes_paths(clickzetta_module):
    """Index creation is skipped for existing tables and honors the inverted-index flag."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._create_vector_index = MagicMock()
    vector._create_inverted_index = MagicMock()

    # Existing table: nothing to create.
    vector._table_exists = MagicMock(return_value=True)
    vector._create_table_and_indexes([[0.1, 0.2]])
    vector._create_vector_index.assert_not_called()

    vector._table_exists = MagicMock(return_value=False)

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    # Missing table with the inverted index enabled: both indexes are built.
    vector.get_connection_context = _fake_context
    vector._create_table_and_indexes([[0.1, 0.2, 0.3]])
    vector._create_vector_index.assert_called_once()
    vector._create_inverted_index.assert_called_once()

    # Inverted index disabled: only the vector index is built.
    vector._config.enable_inverted_index = False
    vector._create_vector_index.reset_mock()
    vector._create_inverted_index.reset_mock()
    vector._create_table_and_indexes([])
    vector._create_vector_index.assert_called_once()
    vector._create_inverted_index.assert_not_called()
def test_create_vector_index_branches(clickzetta_module):
    """Cover existing-index, show-failure, benign-create-failure and fatal paths."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()

    # Index already present: only the first statement (the index listing) runs.
    cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")]
    vector._create_vector_index(cursor)
    assert cursor.execute.call_count == 1

    # Listing fails: fall through to the CREATE statement.
    cursor.reset_mock()
    cursor.execute.side_effect = [RuntimeError("show index failed"), None]
    vector._create_vector_index(cursor)
    assert cursor.execute.call_count == 2

    # An "already exists" failure on CREATE is tolerated.
    cursor.reset_mock()
    cursor.execute.side_effect = [None, RuntimeError("already exists")]
    cursor.fetchall.return_value = []
    vector._create_vector_index(cursor)

    # Any other CREATE failure propagates.
    cursor.reset_mock()
    cursor.execute.side_effect = [None, RuntimeError("unexpected")]
    cursor.fetchall.return_value = []
    with pytest.raises(RuntimeError, match="unexpected"):
        vector._create_vector_index(cursor)
def test_create_inverted_index_branches(clickzetta_module):
    """Inverted-index creation tolerates existing indexes and create failures."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()

    # Already indexed: only the listing statement runs.
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    vector._create_inverted_index(cursor)
    assert cursor.execute.call_count == 1

    # Listing failure falls back to CREATE.
    cursor.reset_mock()
    cursor.execute.side_effect = [RuntimeError("show failed"), None]
    vector._create_inverted_index(cursor)
    assert cursor.execute.call_count == 2

    # "already has index" during CREATE is tolerated.
    cursor.reset_mock()
    cursor.execute.side_effect = [
        None,
        RuntimeError("already has index"),
        None,
    ]
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    vector._create_inverted_index(cursor)

    # Other CREATE failures are swallowed as well (inverted index is best-effort).
    cursor.reset_mock()
    cursor.execute.side_effect = [None, RuntimeError("other create failure")]
    cursor.fetchall.return_value = []
    vector._create_inverted_index(cursor)
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch):
    """add_texts splits writes into batches; _insert_batch handles edge cases and errors."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.batch_size = 2
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()
    vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id))

    documents = [
        Document(page_content="doc-1", metadata={"doc_id": "id-1"}),
        Document(page_content="doc-2", metadata={"doc_id": "id-2"}),
        Document(page_content="doc-3", metadata={"doc_id": "id-3"}),
    ]
    embeddings = [[0.1], [0.2], [0.3]]

    # No documents: nothing is scheduled.
    vector.add_texts([], [])
    vector._execute_write.assert_not_called()

    # Three documents with batch_size=2 produce two scheduled batches.
    returned_ids = vector.add_texts(documents, embeddings)
    assert returned_ids == ["id-1", "id-2", "id-3"]
    assert vector._execute_write.call_count == 2
    assert vector._execute_write.call_args_list[0].args == (
        vector._insert_batch,
        documents[:2],
        embeddings[:2],
        ["id-1", "id-2"],
        0,
        2,
        2,
    )
    assert vector._execute_write.call_args_list[1].args == (
        vector._insert_batch,
        documents[2:],
        embeddings[2:],
        ["id-3"],
        2,
        2,
        2,
    )

    # Degenerate batches (empty, mismatched lengths) are handled without error.
    vector._insert_batch([], [], [], 0, 2, 1)
    vector._insert_batch(documents[:1], embeddings, ["id-1"], 0, 2, 1)

    unserializable_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}})
    serializable_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"})

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    # A metadata value that cannot be JSON-serialized must not abort the batch.
    vector.get_connection_context = _fake_context
    vector._insert_batch(
        [unserializable_doc, serializable_doc],
        [[0.1, 0.2], [0.3, 0.4]],
        ["id-bad", "id-good"],
        0,
        2,
        1,
    )

    @contextmanager
    def _failing_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.executemany.side_effect = RuntimeError("insert failed")
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _failing_context
    with pytest.raises(RuntimeError, match="insert failed"):
        vector._insert_batch([serializable_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1)

    # Invalid doc ids fall back to a generated UUID.
    monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id")
    vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector)
    assert vector._safe_doc_id("") == "generated-id"
    assert vector._safe_doc_id("!!!") == "generated-id"
def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module):
    """Deletes are routed through _execute_write; the impl helpers run against a cursor."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()
    vector._table_exists = MagicMock(return_value=True)

    vector.delete_by_ids(["id-1", "id-2"])
    vector._execute_write.assert_called_once()
    assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl

    vector._execute_write.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._execute_write.assert_called_once()
    assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl

    vector._safe_doc_id = MagicMock(side_effect=lambda x: x)

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    # The impl helpers should run cleanly against the fake cursor.
    vector.get_connection_context = _fake_context
    vector._delete_by_ids_impl(["id-1", "id-2"])
    vector._delete_by_metadata_field_impl("document_id", "doc-1")
def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module):
    """Distance-to-score conversion differs between cosine and l2 distance functions."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.vector_distance_function = "cosine_distance"
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)
    vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"})

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        # One row with raw distance 0.2.
        cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)]
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _fake_context
    # Cosine path: distance 0.2 maps to score 0.9.
    cosine_docs = vector.search_by_vector(
        [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"}
    )
    assert cosine_docs[0].metadata["score"] == pytest.approx(0.9)

    # L2 path: distance 0.2 maps to score 1 / (1 + 0.2).
    vector._config.vector_distance_function = "l2_distance"
    l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5)
    assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2)
def test_search_by_full_text_success_and_fallback(clickzetta_module):
    """Full-text search parses metadata rows and falls back to LIKE search on failure."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)

    @contextmanager
    def _successful_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        # One row with double-encoded JSON metadata and one with invalid JSON.
        cursor.fetchall.return_value = [
            ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'),
            ("seg-2", "content-2", "invalid-json"),
        ]
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _successful_context
    docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1})
    assert len(docs) == 2
    assert docs[0].metadata["score"] == 1.0
    # Invalid metadata falls back to using the row id as doc_id.
    assert docs[1].metadata["doc_id"] == "seg-2"

    @contextmanager
    def _failing_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("full text failed")
        connection.cursor.return_value = cursor
        yield connection

    # A failed MATCH query falls back to the LIKE-based search.
    vector.get_connection_context = _failing_context
    vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})])
    fallback_docs = vector.search_by_full_text("query", top_k=1)
    assert fallback_docs == vector._search_by_like.return_value
def test_search_by_like_missing_table_and_delete_table(clickzetta_module):
    """LIKE search on a missing table yields nothing; delete() runs via the cursor."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=False)
    assert vector._search_by_like("query", top_k=1) == []

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _fake_context
    vector.delete()
def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module):
    """Returning an invalid connection closes it; cleanup tolerates missing pool locks."""
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._pools = {}
    pool._pool_locks = {}
    pool._max_pool_size = 1
    pool._connection_timeout = 10
    pool._lock = clickzetta_module.threading.Lock()
    pool._shutdown = False

    config = _config(clickzetta_module)
    key = pool._get_config_key(config)
    pool._pools[key] = [(MagicMock(), 1.0)]
    pool._pool_locks[key] = clickzetta_module.threading.Lock()
    pool._is_connection_valid = MagicMock(return_value=False)

    # An invalid connection is closed instead of being pooled.
    invalid_connection = MagicMock()
    pool.return_connection(config, invalid_connection)
    invalid_connection.close.assert_called_once()

    # A pool entry without a matching lock must not break cleanup or shutdown.
    pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)]
    pool._cleanup_expired_connections()
    pool.shutdown()
    assert pool._shutdown is True
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch):
    """The cleanup worker survives an exception raised by the cleanup routine."""
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._shutdown = False

    def _shutdown_then_raise():
        # Flip the shutdown flag so the loop exits, then raise to hit the except path.
        pool._shutdown = True
        raise RuntimeError("cleanup failed")

    pool._cleanup_expired_connections = MagicMock(side_effect=_shutdown_then_raise)
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)

    class _InlineThread:
        """Thread substitute that runs the target synchronously inside start()."""

        def __init__(self, target, daemon):
            self._target = target
            self.daemon = daemon

        def start(self):
            self._target()

    monkeypatch.setattr(clickzetta_module.threading, "Thread", _InlineThread)
    pool._start_cleanup_thread()
    pool._cleanup_expired_connections.assert_called_once()
def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module):
    """Non-dict/None metadata falls back to row ids; the worker exits on bad queues."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)

    # JSON that parses to a non-dict value falls back to the row id.
    non_dict_result = vector._parse_metadata("[1,2,3]", "row-1")
    assert non_dict_result["doc_id"] == "row-1"
    assert non_dict_result["document_id"] == "row-1"

    # Missing metadata behaves the same way.
    none_result = vector._parse_metadata(None, "row-2")
    assert none_result["doc_id"] == "row-2"
    assert none_result["document_id"] == "row-2"

    # Worker returns immediately when the queue was never initialized.
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue = None
    clickzetta_module.ClickzettaVector._write_worker()

    class _ExplodingQueue:
        """Queue whose get() flips shutdown and raises, exercising the error path."""

        def get(self, timeout):
            clickzetta_module.ClickzettaVector._shutdown = True
            raise RuntimeError("queue failed")

    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue = _ExplodingQueue()
    clickzetta_module.ClickzettaVector._write_worker()
def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module):
    """'already has index' errors are tolerated; non-dict metadata inserts still work."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    cursor.execute.side_effect = [
        None,
        RuntimeError("already has index with the same type cannot create inverted index"),
        None,
    ]

    vector._create_inverted_index(cursor)

    vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value))

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        raw_cursor = MagicMock()
        raw_cursor.__enter__.return_value = raw_cursor
        raw_cursor.__exit__.return_value = None
        connection.cursor.return_value = raw_cursor
        yield connection

    # A document whose metadata is not a dict must still insert without error.
    vector.get_connection_context = _fake_context
    vector._insert_batch(
        [SimpleNamespace(page_content="content", metadata="not-a-dict")],
        [[0.1, 0.2]],
        ["doc-1"],
        0,
        1,
        1,
    )
def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module):
    """Full-text search returns [] without a table and normalizes odd metadata rows."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.enable_inverted_index = True
    vector._table_name = "table_1"

    # Missing table short-circuits to an empty result.
    vector._table_exists = MagicMock(return_value=False)
    assert vector.search_by_full_text("query") == []

    vector._table_exists = MagicMock(return_value=True)

    @contextmanager
    def _fake_context():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        # One row with non-dict JSON metadata and one with no metadata at all.
        cursor.fetchall.return_value = [
            ("seg-1", "content-1", "[1,2,3]"),
            ("seg-2", "content-2", None),
        ]
        connection.cursor.return_value = cursor
        yield connection

    vector.get_connection_context = _fake_context
    docs = vector.search_by_full_text("query")
    assert len(docs) == 2
    assert docs[0].metadata["doc_id"] == "seg-1"
    assert docs[1].metadata["doc_id"] == "seg-2"
@ -0,0 +1,364 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_couchbase_modules():
|
||||
couchbase = types.ModuleType("couchbase")
|
||||
couchbase_auth = types.ModuleType("couchbase.auth")
|
||||
couchbase_cluster = types.ModuleType("couchbase.cluster")
|
||||
couchbase_management = types.ModuleType("couchbase.management")
|
||||
couchbase_management_search = types.ModuleType("couchbase.management.search")
|
||||
couchbase_options = types.ModuleType("couchbase.options")
|
||||
couchbase_vector = types.ModuleType("couchbase.vector_search")
|
||||
couchbase_search = types.ModuleType("couchbase.search")
|
||||
|
||||
class PasswordAuthenticator:
|
||||
def __init__(self, user, password):
|
||||
self.user = user
|
||||
self.password = password
|
||||
|
||||
class ClusterOptions:
|
||||
def __init__(self, auth):
|
||||
self.auth = auth
|
||||
|
||||
class SearchOptions:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
class VectorQuery:
|
||||
def __init__(self, field, vector, top_k):
|
||||
self.field = field
|
||||
self.vector = vector
|
||||
self.top_k = top_k
|
||||
|
||||
class VectorSearch:
|
||||
@staticmethod
|
||||
def from_vector_query(vector_query):
|
||||
return {"vector_query": vector_query}
|
||||
|
||||
class QueryStringQuery:
|
||||
def __init__(self, query):
|
||||
self.query = query
|
||||
|
||||
class SearchRequest:
|
||||
@staticmethod
|
||||
def create(payload):
|
||||
return {"payload": payload}
|
||||
|
||||
class SearchIndex:
|
||||
def __init__(self, name, params, source_name):
|
||||
self.name = name
|
||||
self.params = params
|
||||
self.source_name = source_name
|
||||
|
||||
class _QueryResult:
|
||||
def __init__(self, rows=None):
|
||||
self._rows = rows or []
|
||||
|
||||
def execute(self):
|
||||
return self
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._rows)
|
||||
|
||||
class _SearchIter:
|
||||
def __init__(self, rows=None):
|
||||
self._rows = rows or []
|
||||
|
||||
def rows(self):
|
||||
return self._rows
|
||||
|
||||
class _Collection:
|
||||
def __init__(self):
|
||||
self.upsert = MagicMock(return_value=True)
|
||||
|
||||
class _SearchIndexManager:
|
||||
def __init__(self):
|
||||
self.upsert_index = MagicMock()
|
||||
|
||||
class _Scope:
|
||||
def __init__(self):
|
||||
self._collection = _Collection()
|
||||
self._search_index_manager = _SearchIndexManager()
|
||||
self.search = MagicMock(return_value=_SearchIter())
|
||||
|
||||
def collection(self, _name):
|
||||
return self._collection
|
||||
|
||||
def search_indexes(self):
|
||||
return self._search_index_manager
|
||||
|
||||
class _CollectionManager:
|
||||
def __init__(self):
|
||||
self.create_collection = MagicMock()
|
||||
self.drop_collection = MagicMock()
|
||||
self.get_all_scopes = MagicMock(return_value=[])
|
||||
|
||||
class _Bucket:
|
||||
def __init__(self):
|
||||
self._scope = _Scope()
|
||||
self._collections = _CollectionManager()
|
||||
|
||||
def scope(self, _scope_name):
|
||||
return self._scope
|
||||
|
||||
def collections(self):
|
||||
return self._collections
|
||||
|
||||
class Cluster:
    """Fake couchbase Cluster capturing connection args and serving one bucket."""

    def __init__(self, connection_string, options):
        # Keep what the caller connected with, for later inspection.
        self.connection_string, self.options = connection_string, options
        self._bucket = _Bucket()
        self.wait_until_ready = MagicMock()
        # query() defaults to an empty result; tests override return_value.
        self.query = MagicMock(return_value=_QueryResult())

    def bucket(self, _name):
        return self._bucket
|
||||
|
||||
# Attach the fakes to the synthetic modules so that imports such as
# `from couchbase.cluster import Cluster` resolve inside the module under test.
couchbase_auth.PasswordAuthenticator = PasswordAuthenticator
couchbase_cluster.Cluster = Cluster
couchbase_management_search.SearchIndex = SearchIndex
couchbase_options.ClusterOptions = ClusterOptions
couchbase_options.SearchOptions = SearchOptions
couchbase_vector.VectorQuery = VectorQuery
couchbase_vector.VectorSearch = VectorSearch
couchbase_search.QueryStringQuery = QueryStringQuery
couchbase_search.SearchRequest = SearchRequest

# Wire submodules onto the parent package object for attribute-style access.
couchbase.search = couchbase_search
couchbase.management = couchbase_management

# Keys are the dotted module names to be planted into sys.modules.
return {
    "couchbase": couchbase,
    "couchbase.auth": couchbase_auth,
    "couchbase.cluster": couchbase_cluster,
    "couchbase.management": couchbase_management,
    "couchbase.management.search": couchbase_management_search,
    "couchbase.options": couchbase_options,
    "couchbase.vector_search": couchbase_vector,
    "couchbase.search": couchbase_search,
}
|
||||
|
||||
|
||||
@pytest.fixture
def couchbase_module(monkeypatch):
    """Reload the couchbase vector module against the fake SDK modules.

    The fakes are installed via ``monkeypatch.setitem`` so sys.modules is
    restored after the test; the reload makes the module re-resolve its
    couchbase imports against them.
    """
    # NOTE: the original rebound the loop variable `module` with the import
    # alias below; distinct names avoid that shadowing.
    for name, fake in _build_fake_couchbase_modules().items():
        monkeypatch.setitem(sys.modules, name, fake)

    import core.rag.datasource.vdb.couchbase.couchbase_vector as couchbase_vector_module

    return importlib.reload(couchbase_vector_module)
|
||||
|
||||
|
||||
def _config(module):
|
||||
return module.CouchbaseConfig(
|
||||
connection_string="couchbase://localhost",
|
||||
user="user",
|
||||
password="pass",
|
||||
bucket_name="bucket",
|
||||
scope_name="scope",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("connection_string", "", "CONNECTION_STRING is required"),
        ("user", "", "COUCHBASE_USER is required"),
        ("password", "", "COUCHBASE_PASSWORD is required"),
        # NOTE(review): bucket_name reuses the PASSWORD message — this mirrors
        # what the validator under test currently emits (likely a copy-paste
        # there); confirm against the config class before changing.
        ("bucket_name", "", "COUCHBASE_PASSWORD is required"),
        ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"),
    ],
)
def test_couchbase_config_validation(couchbase_module, field, value, message):
    """Every required config field rejects an empty value with its message."""
    payload = _config(couchbase_module).model_dump()
    payload[field] = value
    with pytest.raises(ValidationError, match=message):
        couchbase_module.CouchbaseConfig.model_validate(payload)
|
||||
|
||||
|
||||
def test_init_sets_cluster_handles(couchbase_module):
    """The constructor wires bucket/scope names and waits for cluster readiness."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))

    assert (vector._bucket_name, vector._scope_name) == ("bucket", "scope")
    vector._cluster.wait_until_ready.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_create_collection_branches(couchbase_module, monkeypatch):
    """Cover create() plus all three _create_collection() branches."""
    # Bare instance: bypass __init__ so no real cluster connection is attempted.
    vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector)
    vector._collection_name = "collection_1"
    vector._client_config = _config(couchbase_module)
    vector._scope_name = "scope"
    vector._bucket_name = "bucket"
    vector._bucket = MagicMock()
    vector._scope = MagicMock()
    vector._collection_exists = MagicMock(return_value=False)
    vector.add_texts = MagicMock()

    monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c")
    vector._create_collection = MagicMock()
    docs = [Document(page_content="text", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1, 0.2]])

    # create() strips dashes from the uuid ("a-b-c" -> "abc") and derives
    # vector_length from the first embedding's dimensionality.
    vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock())

    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    # Branch 1: redis cache hit -> DDL is skipped entirely.
    monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(vector_length=2, uuid="uuid-1")
    vector._bucket.collections().create_collection.assert_not_called()

    # Branch 2: cache miss but the collection already exists -> still no DDL.
    monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._collection_exists = MagicMock(return_value=True)
    vector._create_collection(vector_length=2, uuid="uuid-2")
    vector._bucket.collections().create_collection.assert_not_called()

    # Branch 3: cache miss and missing collection -> create it and upsert
    # the FTS search index, then mark the cache.
    vector._collection_exists = MagicMock(return_value=False)
    vector._create_collection(vector_length=3, uuid="uuid-3")

    vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1")
    vector._scope.search_indexes().upsert_index.assert_called_once()
    search_index = vector._scope.search_indexes().upsert_index.call_args.args[0]
    assert search_index.name == "collection_1_search"
    # The embedding field's dims in the index mapping must match vector_length.
    assert (
        search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"]
        == 3
    )
    couchbase_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_collection_exists_get_type_and_add_texts(couchbase_module):
    """_collection_exists scans scopes; add_texts upserts one row per document."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))

    matching_scope = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="collection_1")])
    vector._bucket.collections().get_all_scopes.return_value = [matching_scope]
    assert vector._collection_exists("collection_1") is True

    other_scope = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="other")])
    vector._bucket.collections().get_all_scopes.return_value = [other_scope]
    assert vector._collection_exists("collection_1") is False

    vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])
    documents = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]

    assert vector.add_texts(documents, [[0.1], [0.2]]) == ["id-1", "id-2"]
    # One upsert per (document, embedding) pair.
    assert vector._scope.collection("collection_1").upsert.call_count == 2
    assert vector.get_type() == couchbase_module.VectorType.COUCHBASE
|
||||
|
||||
|
||||
def test_query_delete_helpers(couchbase_module):
    """text_exists reads N1QL rows; delete helpers issue queries and never raise."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))

    # A non-empty result row means the id exists; an empty result means it does not.
    vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([{"count": 2}]))
    assert vector.text_exists("id-1") is True
    vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([]))
    assert vector.text_exists("id-2") is False

    deletion_result = MagicMock()
    deletion_result.execute.return_value = None
    vector._cluster.query.return_value = deletion_result

    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_document_id("id-1")
    vector.delete_by_metadata_field("document_id", "doc-1")
    # Each helper above issues at least one N1QL query.
    assert vector._cluster.query.call_count >= 3

    # Deletion failures must not propagate to the caller.
    vector._cluster.query.side_effect = RuntimeError("delete failed")
    vector.delete_by_ids(["id-3"])
|
||||
|
||||
|
||||
def test_search_methods_and_format_metadata(couchbase_module):
    """Vector/full-text search filter by score, wrap errors, and strip metadata prefixes."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))

    high_row = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9)
    low_row = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3)
    vector._scope.search.return_value = SimpleNamespace(rows=lambda: [high_row, low_row])

    # Only the row above the 0.5 threshold survives.
    results = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(results) == 1
    assert results[0].page_content == "doc-a"
    assert results[0].metadata["document_id"] == "d-1"
    assert results[0].metadata["score"] == pytest.approx(0.9)

    # SDK errors surface as ValueError("Search failed ...").
    vector._scope.search.side_effect = RuntimeError("search error")
    with pytest.raises(ValueError, match="Search failed"):
        vector.search_by_vector([0.1], top_k=1)

    vector._scope.search.side_effect = None
    text_row = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7)
    vector._scope.search.return_value = SimpleNamespace(rows=lambda: [text_row])
    results = vector.search_by_full_text("hello", top_k=1)
    assert len(results) == 1
    assert results[0].metadata["doc_id"] == "x"

    vector._scope.search.side_effect = RuntimeError("full text failed")
    with pytest.raises(ValueError, match="Search failed"):
        vector.search_by_full_text("hello", top_k=1)

    # "metadata."-prefixed keys are flattened; plain keys pass through.
    assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2}
|
||||
|
||||
|
||||
def test_delete_collection_and_factory(couchbase_module, monkeypatch):
    """delete() drops the matching collection; the factory picks the index name."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    # Two scopes: only the second contains the target collection.
    scopes = [
        SimpleNamespace(collections=[SimpleNamespace(name="other")]),
        SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]),
    ]
    vector._bucket.collections().get_all_scopes.return_value = scopes

    vector.delete()
    # delete() drops from the "_default" scope regardless of where it was found.
    vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1")

    factory = couchbase_module.CouchbaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(
        couchbase_module,
        "current_app",
        SimpleNamespace(
            config={
                "COUCHBASE_CONNECTION_STRING": "couchbase://localhost",
                "COUCHBASE_USER": "user",
                "COUCHBASE_PASSWORD": "pass",
                "COUCHBASE_BUCKET_NAME": "bucket",
                "COUCHBASE_SCOPE_NAME": "scope",
            }
        ),
    )

    with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # Existing index struct -> its class_prefix is reused; otherwise a
    # generated name is used and the dataset's index_struct gets populated.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,121 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _build_fake_elasticsearch_modules():
|
||||
elasticsearch = types.ModuleType("elasticsearch")
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
class Elasticsearch:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.ping = MagicMock(return_value=True)
|
||||
self.info = MagicMock(return_value={"version": {"number": "8.12.0"}})
|
||||
self.indices = SimpleNamespace(
|
||||
refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock()
|
||||
)
|
||||
|
||||
elasticsearch.Elasticsearch = Elasticsearch
|
||||
elasticsearch.ConnectionError = ConnectionError
|
||||
return {"elasticsearch": elasticsearch}
|
||||
|
||||
|
||||
@pytest.fixture
def elasticsearch_ja_module(monkeypatch):
    """Reload the JA vector module (and its base) against the fake ES module."""
    for name, fake in _build_fake_elasticsearch_modules().items():
        monkeypatch.setitem(sys.modules, name, fake)

    import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module
    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module

    # The base module must be re-bound to the fakes before the JA subclass module.
    importlib.reload(base_module)
    return importlib.reload(ja_module)
|
||||
|
||||
|
||||
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
    """A redis cache hit short-circuits both index creation and cache updates."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock())

    # Bare instance so no real client/version checks run.
    vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector)
    vector._collection_name = "test"
    vector._client = MagicMock()

    vector.create_collection([[0.1, 0.2]], [{}])

    vector._client.indices.create.assert_not_called()
    elasticsearch_ja_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch):
    """Cache miss: create the index when absent, skip creation when present."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock())

    # Bare instance: bypass __init__ so no real client is built.
    vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector)
    vector._collection_name = "test"
    vector._client = MagicMock()

    # Path 1: index missing -> created with dims derived from the embedding.
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2, 0.3]], [{}])

    vector._client.indices.create.assert_called_once()
    kwargs = vector._client.indices.create.call_args.kwargs
    assert kwargs["index"] == "test"
    assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3
    elasticsearch_ja_module.redis_client.set.assert_called_once()

    # Path 2: index already present -> no create, but the cache is still marked.
    vector._client.indices.create.reset_mock()
    elasticsearch_ja_module.redis_client.set.reset_mock()
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]], [{}])

    vector._client.indices.create.assert_not_called()
    elasticsearch_ja_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch):
    """The JA factory reuses an existing class_prefix or generates an index name."""
    factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(
        elasticsearch_ja_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_HOST": "localhost",
                "ELASTICSEARCH_PORT": 9200,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
            }
        ),
    )

    with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # First call reuses the stored class_prefix; second call uses the
    # generated name and populates the dataset's index_struct.
    assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,405 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_elasticsearch_modules():
|
||||
elasticsearch = types.ModuleType("elasticsearch")
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
class Elasticsearch:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.ping = MagicMock(return_value=True)
|
||||
self.info = MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}})
|
||||
self.index = MagicMock()
|
||||
self.exists = MagicMock(return_value=False)
|
||||
self.delete = MagicMock()
|
||||
self.search = MagicMock(return_value={"hits": {"hits": []}})
|
||||
self.indices = SimpleNamespace(
|
||||
refresh=MagicMock(),
|
||||
delete=MagicMock(),
|
||||
exists=MagicMock(return_value=False),
|
||||
create=MagicMock(),
|
||||
)
|
||||
|
||||
elasticsearch.Elasticsearch = Elasticsearch
|
||||
elasticsearch.ConnectionError = ConnectionError
|
||||
return {"elasticsearch": elasticsearch}
|
||||
|
||||
|
||||
@pytest.fixture
def elasticsearch_module(monkeypatch):
    """Reload the elasticsearch vector module against the fake ES module.

    ``monkeypatch.setitem`` restores sys.modules after the test; the reload
    makes the module re-resolve its elasticsearch imports against the fakes.
    """
    # NOTE: the original rebound the loop variable `module` with the import
    # alias below; distinct names avoid that shadowing.
    for name, fake in _build_fake_elasticsearch_modules().items():
        monkeypatch.setitem(sys.modules, name, fake)

    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as es_vector_module

    return importlib.reload(es_vector_module)
|
||||
|
||||
|
||||
def _regular_config(module, **overrides):
|
||||
values = {
|
||||
"host": "localhost",
|
||||
"port": 9200,
|
||||
"username": "elastic",
|
||||
"password": "secret",
|
||||
"verify_certs": False,
|
||||
"request_timeout": 10,
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": 3,
|
||||
}
|
||||
values.update(overrides)
|
||||
return module.ElasticSearchConfig.model_validate(values)
|
||||
|
||||
|
||||
def _cloud_config(module, **overrides):
|
||||
values = {
|
||||
"use_cloud": True,
|
||||
"cloud_url": "https://cloud.example:9243",
|
||||
"api_key": "api-key",
|
||||
"verify_certs": True,
|
||||
"ca_certs": "/tmp/ca.pem",
|
||||
"request_timeout": 10,
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": 3,
|
||||
}
|
||||
values.update(overrides)
|
||||
return module.ElasticSearchConfig.model_validate(values)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("values", "message"),
    [
        # Cloud mode requires both cloud_url and api_key.
        ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"),
        ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"),
        # Self-hosted mode requires host, port, username and password.
        ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"),
        ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"),
        ({"host": "h", "port": 9200, "username": None, "password": "p"}, "USERNAME is required"),
        ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"),
    ],
)
def test_elasticsearch_config_validation(elasticsearch_module, values, message):
    """Each invalid cloud/self-hosted combination is rejected with its message."""
    with pytest.raises(ValidationError, match=message):
        elasticsearch_module.ElasticSearchConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_client_cloud_configuration(elasticsearch_module):
    """Cloud config passes cloud URL, api_key and TLS settings to Elasticsearch."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    fake_client = MagicMock()
    fake_client.ping.return_value = True

    with patch.object(elasticsearch_module, "Elasticsearch", return_value=fake_client) as es_factory:
        created = vector._init_client(_cloud_config(elasticsearch_module))

    assert created is fake_client
    call_kwargs = es_factory.call_args.kwargs
    assert call_kwargs["hosts"] == ["https://cloud.example:9243"]
    assert call_kwargs["api_key"] == "api-key"
    assert call_kwargs["verify_certs"] is True
    assert call_kwargs["ca_certs"] == "/tmp/ca.pem"
|
||||
|
||||
|
||||
def test_init_client_regular_https_and_http_fallback(elasticsearch_module):
    """Self-hosted: https hosts keep TLS kwargs; bare hosts fall back to http."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    client = MagicMock()
    client.ping.return_value = True

    # Host already carries the https scheme -> it is kept and TLS kwargs pass through.
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls:
        vector._init_client(
            _regular_config(
                elasticsearch_module,
                host="https://es.example",
                port=9443,
                verify_certs=True,
                ca_certs="/tmp/ca.pem",
            )
        )
    kwargs = es_cls.call_args.kwargs
    assert kwargs["hosts"] == ["https://es.example:9443"]
    assert kwargs["verify_certs"] is True
    assert kwargs["ca_certs"] == "/tmp/ca.pem"

    # Scheme-less host -> http:// prefix and no TLS kwargs at all.
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls:
        vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200))
    kwargs = es_cls.call_args.kwargs
    assert kwargs["hosts"] == ["http://es.internal:9200"]
    assert "verify_certs" not in kwargs
|
||||
|
||||
|
||||
def test_init_client_connection_failures(elasticsearch_module):
    """All client-construction failure modes surface as ConnectionError."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)

    # ping() returning False -> "Failed to connect".
    client = MagicMock()
    client.ping.return_value = False
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client):
        with pytest.raises(ConnectionError, match="Failed to connect"):
            vector._init_client(_regular_config(elasticsearch_module))

    # The SDK's own connection error -> "Vector database connection error".
    with patch.object(
        elasticsearch_module,
        "Elasticsearch",
        side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"),
    ):
        with pytest.raises(ConnectionError, match="Vector database connection error"):
            vector._init_client(_regular_config(elasticsearch_module))

    # Any other exception -> generic "initialization failed".
    with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")):
        with pytest.raises(ConnectionError, match="initialization failed"):
            vector._init_client(_regular_config(elasticsearch_module))
|
||||
|
||||
|
||||
def test_init_get_version_and_check_version(elasticsearch_module):
    """__init__ wires client/version checks; _get_version strips the -SNAPSHOT suffix."""
    with (
        patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client,
        patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version,
        patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version,
    ):
        vector = elasticsearch_module.ElasticSearchVector(
            "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"]
        )

    # The constructor must call all three helpers exactly once.
    init_client.assert_called_once()
    get_version.assert_called_once()
    check_version.assert_called_once()
    assert vector._attributes == ["doc_id"]

    # _get_version drops the "-SNAPSHOT" suffix from the reported number.
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._client = MagicMock()
    vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}}
    assert vector._get_version() == "8.13.2"

    # Versions below 8.0.0 are rejected; 8.0.0 itself is accepted.
    vector._version = "7.17.0"
    with pytest.raises(ValueError, match="greater than 8.0.0"):
        vector._check_version()

    vector._version = "8.0.0"
    vector._check_version()
|
||||
|
||||
|
||||
def test_crud_methods_and_get_type(elasticsearch_module):
    """add_texts, text_exists, the delete helpers and get_type behave as expected."""
    # Bare instance: bypass __init__ so no real client/version checks run.
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock())
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])

    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]

    # One index call per document, then a single refresh of the collection.
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    assert vector._client.index.call_count == 2
    vector._client.indices.refresh.assert_called_once_with(index="collection_1")

    vector._client.exists.return_value = True
    assert vector.text_exists("id-1") is True

    # Empty id list is a no-op; otherwise one delete per id.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector.delete_by_ids(["id-1", "id-2"])
    assert vector._client.delete.call_count == 2

    # delete_by_metadata_field collects matching ids via search, then delegates.
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}}
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "d1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])

    # No hits -> no delegation at all.
    vector.delete_by_ids.reset_mock()
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_metadata_field("doc_id", "d2")
    vector.delete_by_ids.assert_not_called()

    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection_1")
    assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH
|
||||
|
||||
|
||||
def test_search_by_vector_and_full_text(elasticsearch_module):
    """knn search applies score threshold and filters; full text builds a bool query."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()

    # Two hits: only the one above the 0.5 score threshold should survive.
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.8,
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "doc-a",
                        elasticsearch_module.Field.VECTOR: [0.1],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"},
                    },
                },
                {
                    "_score": 0.2,
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "doc-b",
                        elasticsearch_module.Field.VECTOR: [0.2],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"},
                    },
                },
            ]
        }
    }

    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    # The knn clause carries k=top_k, num_candidates=top_k+1 and a filter
    # when document_ids_filter is given.
    knn = vector._client.search.call_args.kwargs["knn"]
    assert knn["k"] == 2
    assert knn["num_candidates"] == 3
    assert "filter" in knn

    # Full-text search wraps the match in a bool query when a filter is present.
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "text-hit",
                        elasticsearch_module.Field.VECTOR: [0.3],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"},
                    }
                }
            ]
        }
    }
    docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"])
    assert len(docs) == 1
    assert docs[0].page_content == "text-hit"
    query = vector._client.search.call_args.kwargs["query"]
    assert "bool" in query
|
||||
|
||||
|
||||
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
    """create() delegates; create_collection covers cache-hit/create/exists paths."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock())

    # create() = create_collection + add_texts, both mocked here.
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock())

    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="a", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1]])

    # Fresh bare instance for the real create_collection branches.
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock())

    # Path 1: redis cache hit -> no index creation.
    monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()

    # Path 2: cache miss + index missing -> created with dims from the embedding.
    monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_called_once()
    mappings = vector._client.indices.create.call_args.kwargs["mappings"]
    assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2
    elasticsearch_module.redis_client.set.assert_called_once()

    # Path 3: cache miss + index present -> no create, but cache is marked.
    vector._client.indices.create.reset_mock()
    elasticsearch_module.redis_client.set.reset_mock()
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    elasticsearch_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch):
    """The factory builds self-hosted, cloud, and cloud-fallback configurations."""
    factory = elasticsearch_module.ElasticSearchVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")

    # Branch 1: self-hosted config; an existing index_struct reuses its prefix.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": False,
                "ELASTICSEARCH_HOST": "es-host",
                "ELASTICSEARCH_PORT": 9200,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
                "ELASTICSEARCH_VERIFY_CERTS": False,
            }
        ),
    )

    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        assert result_1 == "vector"
        cfg = vector_cls.call_args.kwargs["config"]
        assert cfg.use_cloud is False
        assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION"

    # Branch 2: cloud config with a cloud URL; missing index_struct is populated.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": True,
                "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic",
                "ELASTICSEARCH_API_KEY": "api-key",
                "ELASTICSEARCH_VERIFY_CERTS": True,
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
        assert result_2 == "vector"
        cfg = vector_cls.call_args.kwargs["config"]
        assert cfg.use_cloud is True
        assert cfg.cloud_url == "https://cloud.elastic"
        assert dataset_without_index.index_struct is not None

    # Branch 3: cloud requested but no cloud URL -> falls back to self-hosted.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": True,
                "ELASTICSEARCH_CLOUD_URL": None,
                "ELASTICSEARCH_HOST": "fallback-host",
                "ELASTICSEARCH_PORT": 9201,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
        cfg = vector_cls.call_args.kwargs["config"]
        assert cfg.use_cloud is False
        assert cfg.host == "fallback-host"
|
||||
@ -0,0 +1,371 @@
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_hologres_modules():
|
||||
holo_module = types.ModuleType("holo_search_sdk")
|
||||
holo_types_module = types.ModuleType("holo_search_sdk.types")
|
||||
|
||||
holo_types_module.BaseQuantizationType = str
|
||||
holo_types_module.DistanceType = str
|
||||
holo_types_module.TokenizerType = str
|
||||
|
||||
def _connect(**kwargs):
|
||||
client = MagicMock()
|
||||
client.kwargs = kwargs
|
||||
client.connect = MagicMock()
|
||||
client.check_table_exist = MagicMock(return_value=False)
|
||||
client.open_table = MagicMock(return_value=MagicMock())
|
||||
client.execute = MagicMock(return_value=[])
|
||||
client.drop_table = MagicMock()
|
||||
return client
|
||||
|
||||
holo_module.connect = MagicMock(side_effect=_connect)
|
||||
|
||||
return {
|
||||
"holo_search_sdk": holo_module,
|
||||
"holo_search_sdk.types": holo_types_module,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hologres_module(monkeypatch):
|
||||
for name, module in _build_fake_hologres_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.hologres.hologres_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
|
||||
def _valid_config(module):
|
||||
return module.HologresVectorConfig(
|
||||
host="localhost",
|
||||
port=80,
|
||||
database="dify",
|
||||
access_key_id="ak",
|
||||
access_key_secret="sk",
|
||||
schema_name="public",
|
||||
tokenizer="jieba",
|
||||
distance_method="Cosine",
|
||||
base_quantization_type="rabitq",
|
||||
max_degree=64,
|
||||
ef_construction=400,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field", "value", "message"),
|
||||
[
|
||||
("host", "", "config HOLOGRES_HOST is required"),
|
||||
("database", "", "config HOLOGRES_DATABASE is required"),
|
||||
("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"),
|
||||
("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"),
|
||||
],
|
||||
)
|
||||
def test_hologres_config_validation(hologres_module, field, value, message):
|
||||
values = _valid_config(hologres_module).model_dump()
|
||||
values[field] = value
|
||||
|
||||
with pytest.raises(ValidationError, match=message):
|
||||
hologres_module.HologresVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_client_and_get_type(hologres_module):
|
||||
vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module))
|
||||
|
||||
hologres_module.holo.connect.assert_called_once_with(
|
||||
host="localhost",
|
||||
port=80,
|
||||
database="dify",
|
||||
access_key_id="ak",
|
||||
access_key_secret="sk",
|
||||
schema="public",
|
||||
)
|
||||
vector._client.connect.assert_called_once()
|
||||
assert vector.table_name == "embedding_collection_one"
|
||||
assert vector.get_type() == hologres_module.VectorType.HOLOGRES
|
||||
|
||||
|
||||
def test_create_delegates_collection_creation_and_upsert(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._create_collection = MagicMock()
|
||||
vector.add_texts = MagicMock()
|
||||
docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
|
||||
|
||||
result = vector.create(docs, [[0.1, 0.2]])
|
||||
|
||||
assert result is None
|
||||
vector._create_collection.assert_called_once_with(2)
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_returns_empty_for_empty_documents(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
|
||||
assert vector.add_texts([], []) == []
|
||||
vector._client.open_table.assert_not_called()
|
||||
|
||||
|
||||
def test_add_texts_batches_and_serializes_metadata(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
table = vector._client.open_table.return_value
|
||||
documents = [
|
||||
Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"})
|
||||
for i in range(100)
|
||||
]
|
||||
documents.append(SimpleNamespace(page_content="doc-100", metadata=None))
|
||||
embeddings = [[float(i)] for i in range(len(documents))]
|
||||
|
||||
ids = vector.add_texts(documents, embeddings)
|
||||
|
||||
assert ids[:2] == ["id-0", "id-1"]
|
||||
assert ids[-1] == ""
|
||||
assert len(ids) == 101
|
||||
assert vector._client.open_table.call_count == 2
|
||||
assert table.upsert_multi.call_count == 2
|
||||
first_call = table.upsert_multi.call_args_list[0].kwargs
|
||||
second_call = table.upsert_multi.call_args_list[1].kwargs
|
||||
assert first_call["index_column"] == "id"
|
||||
assert first_call["column_names"] == ["id", "text", "meta", "embedding"]
|
||||
assert first_call["update_columns"] == ["text", "meta", "embedding"]
|
||||
assert len(first_call["values"]) == 100
|
||||
assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"}
|
||||
assert second_call["values"][0][0] == ""
|
||||
assert second_call["values"][0][2] == "{}"
|
||||
|
||||
|
||||
def test_text_exists_handles_missing_and_present_tables(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.side_effect = [False, True]
|
||||
vector._client.execute.return_value = [(1,)]
|
||||
|
||||
assert vector.text_exists("seg-1") is False
|
||||
assert vector.text_exists("seg-1") is True
|
||||
vector._client.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []]
|
||||
|
||||
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
|
||||
assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
|
||||
|
||||
|
||||
def test_delete_by_ids_branches(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
|
||||
vector.delete_by_ids([])
|
||||
vector._client.check_table_exist.assert_not_called()
|
||||
|
||||
vector._client.check_table_exist.return_value = False
|
||||
vector.delete_by_ids(["id-1"])
|
||||
vector._client.execute.assert_not_called()
|
||||
|
||||
vector._client.check_table_exist.return_value = True
|
||||
vector.delete_by_ids(["id-1", "id-2"])
|
||||
vector._client.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_by_metadata_field_branches(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.return_value = False
|
||||
|
||||
vector.delete_by_metadata_field("document_id", "doc-1")
|
||||
vector._client.execute.assert_not_called()
|
||||
|
||||
vector._client.check_table_exist.return_value = True
|
||||
vector.delete_by_metadata_field("document_id", "doc-1")
|
||||
vector._client.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_search_by_vector_returns_empty_when_table_missing(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.return_value = False
|
||||
|
||||
assert vector.search_by_vector([0.1, 0.2]) == []
|
||||
|
||||
|
||||
def test_search_by_vector_applies_filter_and_processes_results(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.return_value = True
|
||||
table = vector._client.open_table.return_value
|
||||
query = MagicMock()
|
||||
table.search_vector.return_value = query
|
||||
query.select.return_value = query
|
||||
query.limit.return_value = query
|
||||
query.where.return_value = query
|
||||
query.fetchall.return_value = [
|
||||
(0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'),
|
||||
(0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}),
|
||||
]
|
||||
|
||||
docs = vector.search_by_vector(
|
||||
[0.1, 0.2],
|
||||
top_k=2,
|
||||
score_threshold=0.5,
|
||||
document_ids_filter=["doc-1"],
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "doc-1"
|
||||
assert docs[0].metadata["doc_id"] == "seg-1"
|
||||
assert docs[0].metadata["score"] == pytest.approx(0.8)
|
||||
table.search_vector.assert_called_once()
|
||||
query.where.assert_called_once()
|
||||
|
||||
|
||||
def test_search_by_full_text_returns_empty_when_table_missing(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.return_value = False
|
||||
|
||||
assert vector.search_by_full_text("query") == []
|
||||
|
||||
|
||||
def test_search_by_full_text_applies_filter_and_processes_results(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.return_value = True
|
||||
table = vector._client.open_table.return_value
|
||||
search_query = MagicMock()
|
||||
table.search_text.return_value = search_query
|
||||
search_query.limit.return_value = search_query
|
||||
search_query.where.return_value = search_query
|
||||
search_query.fetchall.return_value = [
|
||||
("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95),
|
||||
("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7),
|
||||
]
|
||||
|
||||
docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"])
|
||||
|
||||
assert len(docs) == 2
|
||||
assert docs[0].metadata["doc_id"] == "seg-1"
|
||||
assert docs[0].metadata["score"] == pytest.approx(0.95)
|
||||
assert docs[1].metadata["score"] == pytest.approx(0.7)
|
||||
table.search_text.assert_called_once()
|
||||
search_query.where.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_handles_existing_and_missing_tables(hologres_module):
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.side_effect = [False, True]
|
||||
|
||||
vector.delete()
|
||||
vector._client.drop_table.assert_not_called()
|
||||
|
||||
vector.delete()
|
||||
vector._client.drop_table.assert_called_once_with(vector.table_name)
|
||||
|
||||
|
||||
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1))
|
||||
monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
|
||||
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._create_collection(3)
|
||||
|
||||
vector._client.check_table_exist.assert_not_called()
|
||||
hologres_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
|
||||
monkeypatch.setattr(hologres_module.time, "sleep", MagicMock())
|
||||
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.side_effect = [False, False, True]
|
||||
table = vector._client.open_table.return_value
|
||||
|
||||
vector._create_collection(3)
|
||||
|
||||
vector._client.execute.assert_called_once()
|
||||
table.set_vector_index.assert_called_once_with(
|
||||
column="embedding",
|
||||
distance_method="Cosine",
|
||||
base_quantization_type="rabitq",
|
||||
max_degree=64,
|
||||
ef_construction=400,
|
||||
use_reorder=True,
|
||||
)
|
||||
table.create_text_index.assert_called_once_with(
|
||||
index_name="ft_idx_collection_one",
|
||||
column="text",
|
||||
tokenizer="jieba",
|
||||
)
|
||||
hologres_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
|
||||
monkeypatch.setattr(hologres_module.time, "sleep", MagicMock())
|
||||
|
||||
vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
|
||||
vector._client.check_table_exist.side_effect = [False] + [False] * 15
|
||||
|
||||
with pytest.raises(RuntimeError, match="was not ready after 30s"):
|
||||
vector._create_collection(3)
|
||||
|
||||
hologres_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch):
|
||||
factory = hologres_module.HologresVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}},
|
||||
index_struct=None,
|
||||
)
|
||||
dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
|
||||
|
||||
monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80)
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq")
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64)
|
||||
monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_EF_CONSTRUCTION", 400)
|
||||
|
||||
with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls:
|
||||
result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
|
||||
result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
|
||||
|
||||
assert result_1 == "vector"
|
||||
assert result_2 == "vector"
|
||||
assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
|
||||
assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection"
|
||||
generated_config = vector_cls.call_args_list[1].kwargs["config"]
|
||||
assert generated_config.host == "127.0.0.1"
|
||||
assert generated_config.database == "dify"
|
||||
assert generated_config.access_key_id == "ak"
|
||||
assert json.loads(dataset_without_index.index_struct) == {
|
||||
"type": hologres_module.VectorType.HOLOGRES,
|
||||
"vector_store": {"class_prefix": "generated_collection"},
|
||||
}
|
||||
@ -0,0 +1,243 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_elasticsearch_modules():
|
||||
elasticsearch = types.ModuleType("elasticsearch")
|
||||
|
||||
class Elasticsearch:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.index = MagicMock()
|
||||
self.exists = MagicMock(return_value=False)
|
||||
self.delete = MagicMock()
|
||||
self.search = MagicMock(return_value={"hits": {"hits": []}})
|
||||
self.indices = SimpleNamespace(
|
||||
refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock()
|
||||
)
|
||||
|
||||
elasticsearch.Elasticsearch = Elasticsearch
|
||||
return {"elasticsearch": elasticsearch}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def huawei_module(monkeypatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module):
|
||||
return module.HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="user", password="pass")
|
||||
|
||||
|
||||
def test_create_ssl_context(huawei_module):
|
||||
ctx = huawei_module.create_ssl_context()
|
||||
assert ctx.check_hostname is False
|
||||
assert ctx.verify_mode == huawei_module.ssl.CERT_NONE
|
||||
|
||||
|
||||
def test_huawei_config_validation_and_params(huawei_module):
|
||||
with pytest.raises(ValidationError, match="HOSTS is required"):
|
||||
huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""})
|
||||
|
||||
config = _config(huawei_module)
|
||||
params = config.to_elasticsearch_params()
|
||||
assert params["hosts"] == ["http://localhost:9200"]
|
||||
assert params["basic_auth"] == ("user", "pass")
|
||||
|
||||
config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None)
|
||||
params = config.to_elasticsearch_params()
|
||||
assert "basic_auth" not in params
|
||||
|
||||
|
||||
def test_init_get_type_and_add_texts(huawei_module):
|
||||
vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module))
|
||||
|
||||
assert vector._collection_name == "collection"
|
||||
assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD
|
||||
|
||||
vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])
|
||||
docs = [
|
||||
Document(page_content="a", metadata={"doc_id": "id-1"}),
|
||||
Document(page_content="b", metadata={"doc_id": "id-2"}),
|
||||
]
|
||||
|
||||
ids = vector.add_texts(docs, [[0.1], [0.2]])
|
||||
assert ids == ["id-1", "id-2"]
|
||||
assert vector._client.index.call_count == 2
|
||||
vector._client.indices.refresh.assert_called_once_with(index="collection")
|
||||
|
||||
|
||||
def test_crud_methods(huawei_module):
|
||||
vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
|
||||
|
||||
vector._client.exists.return_value = True
|
||||
assert vector.text_exists("id-1") is True
|
||||
|
||||
vector.delete_by_ids([])
|
||||
vector._client.delete.assert_not_called()
|
||||
vector.delete_by_ids(["id-1"])
|
||||
vector._client.delete.assert_called_once_with(index="collection", id="id-1")
|
||||
|
||||
vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}}
|
||||
vector.delete_by_ids = MagicMock()
|
||||
vector.delete_by_metadata_field("doc_id", "x")
|
||||
vector.delete_by_ids.assert_called_once_with(["id-1"])
|
||||
|
||||
vector.delete_by_ids.reset_mock()
|
||||
vector._client.search.return_value = {"hits": {"hits": []}}
|
||||
vector.delete_by_metadata_field("doc_id", "x")
|
||||
vector.delete_by_ids.assert_not_called()
|
||||
|
||||
vector.delete()
|
||||
vector._client.indices.delete.assert_called_once_with(index="collection")
|
||||
|
||||
|
||||
def test_search_by_vector_and_full_text(huawei_module):
|
||||
vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
|
||||
vector._client.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_score": 0.9,
|
||||
"_source": {
|
||||
huawei_module.Field.CONTENT_KEY: "doc-a",
|
||||
huawei_module.Field.VECTOR: [0.1],
|
||||
huawei_module.Field.METADATA_KEY: {"doc_id": "1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"_score": 0.1,
|
||||
"_source": {
|
||||
huawei_module.Field.CONTENT_KEY: "doc-b",
|
||||
huawei_module.Field.VECTOR: [0.2],
|
||||
huawei_module.Field.METADATA_KEY: {"doc_id": "2"},
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["score"] == pytest.approx(0.9)
|
||||
|
||||
query_body = vector._client.search.call_args.kwargs["body"]
|
||||
assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2
|
||||
|
||||
vector._client.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
huawei_module.Field.CONTENT_KEY: "text-hit",
|
||||
huawei_module.Field.VECTOR: [0.3],
|
||||
huawei_module.Field.METADATA_KEY: {"doc_id": "3"},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
docs = vector.search_by_full_text("hello", top_k=3)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "text-hit"
|
||||
|
||||
|
||||
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch):
|
||||
class FakeDocument:
|
||||
def __init__(self, page_content, vector, metadata):
|
||||
self.page_content = page_content
|
||||
self.vector = vector
|
||||
self.metadata = None
|
||||
|
||||
monkeypatch.setattr(huawei_module, "Document", FakeDocument)
|
||||
|
||||
vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
|
||||
vector._client.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_score": 0.9,
|
||||
"_source": {
|
||||
huawei_module.Field.CONTENT_KEY: "doc-a",
|
||||
huawei_module.Field.VECTOR: [0.1],
|
||||
huawei_module.Field.METADATA_KEY: {"doc_id": "1"},
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5)
|
||||
|
||||
assert docs == []
|
||||
|
||||
|
||||
def test_create_and_create_collection_paths(huawei_module, monkeypatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock())
|
||||
|
||||
vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
|
||||
vector.create_collection = MagicMock()
|
||||
vector.add_texts = MagicMock()
|
||||
|
||||
docs = [Document(page_content="a", metadata={"doc_id": "1"})]
|
||||
vector.create(docs, [[0.1]])
|
||||
vector.create_collection.assert_called_once()
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1]])
|
||||
|
||||
vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
|
||||
monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1))
|
||||
vector.create_collection([[0.1, 0.2]], [{}])
|
||||
vector._client.indices.create.assert_not_called()
|
||||
|
||||
monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None))
|
||||
vector._client.indices.exists.return_value = False
|
||||
vector.create_collection([[0.1, 0.2]], [{}])
|
||||
vector._client.indices.create.assert_called_once()
|
||||
|
||||
kwargs = vector._client.indices.create.call_args.kwargs
|
||||
mappings = kwargs["mappings"]
|
||||
assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2
|
||||
assert kwargs["settings"] == {"index.vector": True}
|
||||
huawei_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_huawei_factory_branches(huawei_module, monkeypatch):
|
||||
factory = huawei_module.HuaweiCloudVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
|
||||
index_struct=None,
|
||||
)
|
||||
dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
|
||||
|
||||
monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
|
||||
monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200")
|
||||
monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user")
|
||||
monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass")
|
||||
|
||||
with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls:
|
||||
result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
|
||||
result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
|
||||
|
||||
assert result_1 == "vector"
|
||||
assert result_2 == "vector"
|
||||
assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection"
|
||||
assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection"
|
||||
assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,412 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_iris_module():
|
||||
iris = types.ModuleType("iris")
|
||||
|
||||
def connect(**_kwargs):
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = MagicMock()
|
||||
return conn
|
||||
|
||||
iris.connect = MagicMock(side_effect=connect)
|
||||
return iris
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def iris_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
|
||||
|
||||
import core.rag.datasource.vdb.iris.iris_vector as module
|
||||
|
||||
reloaded = importlib.reload(module)
|
||||
reloaded._pool_instance = None
|
||||
return reloaded
|
||||
|
||||
|
||||
def _config(module, **overrides):
|
||||
values = {
|
||||
"IRIS_HOST": "localhost",
|
||||
"IRIS_SUPER_SERVER_PORT": 1972,
|
||||
"IRIS_USER": "user",
|
||||
"IRIS_PASSWORD": "pass",
|
||||
"IRIS_DATABASE": "db",
|
||||
"IRIS_SCHEMA": "schema",
|
||||
"IRIS_CONNECTION_URL": "url",
|
||||
"IRIS_MIN_CONNECTION": 1,
|
||||
"IRIS_MAX_CONNECTION": 2,
|
||||
"IRIS_TEXT_INDEX": True,
|
||||
"IRIS_TEXT_INDEX_LANGUAGE": "en",
|
||||
}
|
||||
values.update(overrides)
|
||||
return module.IrisVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_get_iris_pool_singleton(iris_module):
|
||||
iris_module._pool_instance = None
|
||||
cfg = _config(iris_module)
|
||||
|
||||
with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls:
|
||||
pool_1 = iris_module.get_iris_pool(cfg)
|
||||
pool_2 = iris_module.get_iris_pool(cfg)
|
||||
|
||||
assert pool_1 == "pool"
|
||||
assert pool_2 == "pool"
|
||||
pool_cls.assert_called_once_with(cfg)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool_with_min_max(iris_module):
|
||||
cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3)
|
||||
with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn:
|
||||
pool = iris_module.IrisConnectionPool(cfg)
|
||||
yield pool, create_conn
|
||||
|
||||
|
||||
def test_pool_initialization_respects_min_max(pool_with_min_max):
|
||||
pool, create_conn = pool_with_min_max
|
||||
assert len(pool._pool) == 2
|
||||
assert create_conn.call_count == 2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool_for_get_connection(iris_module):
|
||||
cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3)
|
||||
pool = iris_module.IrisConnectionPool(cfg)
|
||||
return pool
|
||||
|
||||
|
||||
def test_get_connection_returns_existing_and_increments(pool_for_get_connection):
|
||||
pool = pool_for_get_connection
|
||||
conn = MagicMock()
|
||||
pool._pool = [conn]
|
||||
pool._in_use = 0
|
||||
assert pool.get_connection() is conn
|
||||
assert pool._in_use == 1
|
||||
|
||||
|
||||
def test_get_connection_creates_new_when_empty(pool_for_get_connection):
|
||||
pool = pool_for_get_connection
|
||||
pool._pool = []
|
||||
pool._in_use = 0
|
||||
pool._create_connection = MagicMock(return_value="new-conn")
|
||||
assert pool.get_connection() == "new-conn"
|
||||
|
||||
|
||||
def test_get_connection_raises_when_exhausted(pool_for_get_connection):
|
||||
pool = pool_for_get_connection
|
||||
pool._pool = []
|
||||
pool._in_use = pool._max_size
|
||||
with pytest.raises(RuntimeError, match="exhausted"):
|
||||
pool.get_connection()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool_for_return_connection(iris_module):
|
||||
cfg = _config(iris_module)
|
||||
with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
|
||||
pool = iris_module.IrisConnectionPool(cfg)
|
||||
return pool
|
||||
|
||||
|
||||
def test_return_connection_adds_healthy(pool_for_return_connection):
|
||||
pool = pool_for_return_connection
|
||||
pool._in_use = 1
|
||||
conn = MagicMock()
|
||||
cursor = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
pool.return_connection(conn)
|
||||
assert pool._pool[-1] is conn
|
||||
assert pool._in_use == 0
|
||||
|
||||
|
||||
def test_return_connection_replaces_bad(pool_for_return_connection):
|
||||
pool = pool_for_return_connection
|
||||
pool._in_use = 1
|
||||
bad_conn = MagicMock()
|
||||
bad_cursor = MagicMock()
|
||||
bad_cursor.execute.side_effect = OSError("bad")
|
||||
bad_conn.cursor.return_value = bad_cursor
|
||||
replacement = MagicMock()
|
||||
pool._create_connection = MagicMock(return_value=replacement)
|
||||
pool.return_connection(bad_conn)
|
||||
bad_conn.close.assert_called_once()
|
||||
assert pool._pool[-1] is replacement
|
||||
assert pool._in_use == 0
|
||||
|
||||
|
||||
def test_return_connection_ignores_none(pool_for_return_connection):
|
||||
pool = pool_for_return_connection
|
||||
before = len(pool._pool)
|
||||
pool.return_connection(None)
|
||||
assert len(pool._pool) == before
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool_for_schema_and_close(iris_module):
|
||||
cfg = _config(iris_module)
|
||||
with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
|
||||
pool = iris_module.IrisConnectionPool(cfg)
|
||||
conn = MagicMock()
|
||||
cursor = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
pool._pool = [conn]
|
||||
return pool, conn, cursor
|
||||
|
||||
|
||||
def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close):
|
||||
pool, conn, cursor = pool_for_schema_and_close
|
||||
pool._schemas_initialized = {"cached_schema"}
|
||||
pool.ensure_schema_exists("cached_schema")
|
||||
cursor.execute.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_schema_exists_creates_new(pool_for_schema_and_close):
    """A missing schema (existence query returns 0) is created and committed."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    # Existence probe reports the schema does not exist yet.
    cursor.fetchone.return_value = (0,)
    pool.ensure_schema_exists("new_schema")
    # The schema is cached so later calls short-circuit.
    assert "new_schema" in pool._schemas_initialized
    assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list)
    conn.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close):
    """An already-present schema (existence query returns 1) needs no commit."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    cursor.fetchone.return_value = (1,)
    pool.ensure_schema_exists("existing_schema")
    conn.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close):
    """A failing schema statement is re-raised after the transaction is rolled back."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    cursor.execute.side_effect = RuntimeError("schema failure")
    with pytest.raises(RuntimeError, match="schema failure"):
        pool.ensure_schema_exists("broken_schema")
    conn.rollback.assert_called()
|
||||
|
||||
|
||||
def test_close_all_closes_and_resets(iris_module):
    """close_all closes every pooled connection (tolerating close errors) and resets state."""
    cfg = _config(iris_module)
    with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
        pool = iris_module.IrisConnectionPool(cfg)
    conn = MagicMock()
    conn_2 = MagicMock()
    # One connection fails to close; close_all must not propagate the error.
    conn_2.close.side_effect = OSError("close fail")
    pool._pool = [conn, conn_2]
    pool._schemas_initialized = {"x"}
    pool.close_all()
    assert pool._pool == []
    assert pool._in_use == 0
    assert pool._schemas_initialized == set()
|
||||
|
||||
|
||||
def test_iris_vector_init_get_cursor_and_create(iris_module):
    """Cover IrisVector construction, the _get_cursor commit/rollback paths, and create()."""
    pool = MagicMock()
    pool.get_connection.return_value = MagicMock()

    with patch.object(iris_module, "get_iris_pool", return_value=pool):
        vector = iris_module.IrisVector("collection", _config(iris_module))

    # Table name is derived from the collection name; schema comes from config.
    assert vector.table_name == "EMBEDDING_COLLECTION"
    assert vector.schema == "schema"
    assert vector.get_type() == iris_module.VectorType.IRIS

    # Happy path: cursor is yielded, work is committed, connection is returned.
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    vector.pool.get_connection.return_value = conn

    with vector._get_cursor() as got_cursor:
        assert got_cursor is cursor
    conn.commit.assert_called_once()
    vector.pool.return_connection.assert_called_with(conn)

    # Error path: the exception propagates and the transaction is rolled back.
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    vector.pool.get_connection.return_value = conn
    with pytest.raises(RuntimeError, match="boom"):
        with vector._get_cursor():
            raise RuntimeError("boom")
    conn.rollback.assert_called_once()

    # create(): collection sized from the embedding dimension, then delegates.
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["id-1"])
    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"]
    vector._create_collection.assert_called_once_with(2)
|
||||
|
||||
|
||||
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
    """Cover add_texts, text_exists, the delete helpers, and vector-search filtering."""
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", _config(iris_module))

    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # Deterministic IDs for documents lacking a doc_id in metadata.
    monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id")

    # add_texts: doc_id from metadata when present, uuid4 otherwise; one insert per doc.
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        SimpleNamespace(page_content="b", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "generated-id"]
    assert cursor.execute.call_count == 2

    # text_exists: truthy row -> True, no row -> False.
    cursor.fetchone.return_value = (1,)
    assert vector.text_exists("id-1") is True
    cursor.fetchone.return_value = None
    assert vector.text_exists("id-2") is False

    # text_exists swallows DB errors and reports False.
    vector._get_cursor = MagicMock(side_effect=RuntimeError("db down"))
    assert vector.text_exists("id-3") is False

    vector._get_cursor = _cursor_ctx
    # Empty ID list is a no-op; a populated list issues a single statement.
    vector.delete_by_ids([])
    before = cursor.execute.call_count
    vector.delete_by_ids(["id-1", "id-2"])
    assert cursor.execute.call_count == before + 1

    # Metadata-field delete filters via a LIKE on the serialized metadata column.
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert "meta LIKE" in cursor.execute.call_args.args[0]

    # search_by_vector: the row under the 0.5 threshold and the short row are dropped.
    cursor.fetchall.return_value = [
        ("id-1", "text-1", '{"document_id":"d-1"}', 0.9),
        ("id-2", "text-2", '{"document_id":"d-2"}', 0.2),
        ("id-x",),
    ]
    docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
|
||||
|
||||
|
||||
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
    """Full-text search: ranked query, fallback after a rank failure, and the LIKE path."""
    cfg = _config(iris_module, IRIS_TEXT_INDEX=True)
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", cfg)

    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    # Happy path: a missing score column defaults to 0.0.
    cursor.execute.side_effect = None
    cursor.fetchall.return_value = [
        ("id-1", "text-1", '{"document_id":"d-1"}', 0.7),
        ("id-2", "text-2", "{}", None),
    ]
    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 2
    assert docs[0].metadata["score"] == pytest.approx(0.7)
    assert docs[1].metadata["score"] == pytest.approx(0.0)

    # First (ranked) query raises -> a second fallback query is issued.
    cursor.reset_mock()
    cursor.execute.side_effect = [RuntimeError("rank failed"), None]
    cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)]
    docs = vector.search_by_full_text("query", top_k=1)
    assert len(docs) == 1
    assert cursor.execute.call_count == 2

    # Without a text index the LIKE path is used instead.
    cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False)
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector_like = iris_module.IrisVector("collection", cfg_like)
    vector_like._get_cursor = _cursor_ctx

    # Stand-in libs.helper.escape_like_pattern for the module's import of it.
    fake_libs = types.ModuleType("libs")
    fake_helper = types.ModuleType("libs.helper")
    fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%")
    monkeypatch.setitem(sys.modules, "libs", fake_libs)
    monkeypatch.setitem(sys.modules, "libs.helper", fake_helper)

    cursor.reset_mock()
    cursor.execute.side_effect = None
    cursor.fetchall.return_value = []
    # Query containing a LIKE wildcard still runs cleanly and yields no hits.
    assert vector_like.search_by_full_text("100%", top_k=1) == []
|
||||
|
||||
|
||||
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch):
    """Cover delete(), the _create_collection cache/index branches, and the factory."""
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True))

    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    vector.delete()
    assert "DROP TABLE" in cursor.execute.call_args.args[0]

    # Redis lock acts as a no-op context manager for the branches below.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(iris_module.redis_client, "set", MagicMock())

    # Cache hit: _create_collection issues no SQL (the single recorded call is the DROP above).
    monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(2)
    cursor.execute.assert_called_once()

    # Cache miss with text index enabled: three statements, and the redis marker is set.
    cursor.reset_mock()
    monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None))
    vector.pool.ensure_schema_exists = MagicMock()
    vector._create_collection(3)
    assert cursor.execute.call_count == 3
    iris_module.redis_client.set.assert_called_once()

    # Cache miss without text index: one fewer statement.
    cursor.reset_mock()
    vector.config.IRIS_TEXT_INDEX = False
    vector._create_collection(3)
    assert cursor.execute.call_count == 2

    factory = iris_module.IrisVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en")

    with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # An existing index struct reuses its class_prefix; otherwise a name is generated
    # and the dataset's index_struct is populated.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,394 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_opensearch_modules():
|
||||
opensearchpy = types.ModuleType("opensearchpy")
|
||||
opensearch_helpers = types.ModuleType("opensearchpy.helpers")
|
||||
|
||||
class BulkIndexError(Exception):
|
||||
def __init__(self, errors):
|
||||
super().__init__("bulk error")
|
||||
self.errors = errors
|
||||
|
||||
class OpenSearch:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.indices = SimpleNamespace(
|
||||
refresh=MagicMock(),
|
||||
exists=MagicMock(return_value=False),
|
||||
delete=MagicMock(),
|
||||
create=MagicMock(),
|
||||
)
|
||||
self.bulk = MagicMock(return_value={"errors": False, "items": []})
|
||||
self.search = MagicMock(return_value={"hits": {"hits": []}})
|
||||
self.delete_by_query = MagicMock()
|
||||
self.get = MagicMock(return_value={"_id": "id"})
|
||||
self.exists = MagicMock(return_value=True)
|
||||
|
||||
opensearch_helpers.BulkIndexError = BulkIndexError
|
||||
opensearch_helpers.bulk = MagicMock()
|
||||
|
||||
opensearchpy.OpenSearch = OpenSearch
|
||||
opensearchpy.helpers = opensearch_helpers
|
||||
|
||||
return {
|
||||
"opensearchpy": opensearchpy,
|
||||
"opensearchpy.helpers": opensearch_helpers,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def lindorm_module(monkeypatch):
    """Import the lindorm vector module against the fake opensearchpy modules."""
    for name, module in _build_fake_opensearch_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.lindorm.lindorm_vector as module

    # Reload so the module re-binds to the fakes even if it was imported earlier.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module):
|
||||
return module.LindormVectorStoreConfig(
|
||||
hosts="http://localhost:9200",
|
||||
username="user",
|
||||
password="pass",
|
||||
using_ugc=False,
|
||||
request_timeout=3.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("hosts", None, "config URL is required"),
        ("username", None, "config USERNAME is required"),
        ("password", None, "config PASSWORD is required"),
    ],
)
def test_lindorm_config_validation(lindorm_module, field, value, message):
    """Each required config field, when missing, raises a ValidationError with its message."""
    values = _config(lindorm_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        lindorm_module.LindormVectorStoreConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_to_opensearch_params_and_init(lindorm_module):
    """Config -> client params; store init lowercases names and validates UGC routing."""
    cfg = _config(lindorm_module)
    params = cfg.to_opensearch_params()

    assert params["hosts"] == "http://localhost:9200"
    assert params["http_auth"] == ("user", "pass")

    # Collection names are normalized to lowercase.
    vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False)
    assert vector._collection_name == "collection"
    assert vector.get_type() == lindorm_module.VectorType.LINDORM

    # UGC mode requires a routing value ...
    with pytest.raises(ValueError, match="routing_value"):
        lindorm_module.LindormVectorStore("c", cfg, using_ugc=True)

    # ... which is also lowercased.
    vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE")
    assert vector_ugc._routing == "route"
|
||||
|
||||
|
||||
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch):
    """create() delegates to collection setup + add_texts; add_texts batches bulk calls."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()

    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}])
    vector.add_texts.assert_called_once_with(docs, [[0.1]])

    # Fresh store with real add_texts; sleep patched so the test stays fast.
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock())

    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
        Document(page_content="c", metadata={"doc_id": "id-3"}),
    ]
    embeddings = [[0.1], [0.2], [0.3]]

    # Three docs with batch_size=2 -> two bulk calls.
    vector.add_texts(docs, embeddings, batch_size=2, timeout=9)

    assert vector._client.bulk.call_count == 2
    actions = vector._client.bulk.call_args_list[0].args[0]
    # UGC mode attaches the routing value to both the action and the document.
    assert actions[0]["index"]["routing"] == "route"
    assert actions[1][lindorm_module.ROUTING_FIELD] == "route"
    vector.refresh()
    vector._client.indices.refresh.assert_called_once_with(index="collection")
|
||||
|
||||
|
||||
def test_add_texts_error_paths(lindorm_module):
    """Bulk item errors and bulk call failures both surface as a RetryError after retries."""
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    # Bulk response reports per-item errors.
    vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]}

    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    with pytest.raises(Exception, match="RetryError"):
        vector.add_texts(docs, [[0.1]], batch_size=1)

    # The bulk call itself raising behaves the same way.
    vector._client.bulk.side_effect = RuntimeError("bulk failed")
    with pytest.raises(Exception, match="RetryError"):
        vector.add_texts(docs, [[0.1]], batch_size=1)
|
||||
|
||||
|
||||
def test_metadata_lookup_and_delete_by_metadata(lindorm_module):
    """ID lookup includes a routing term in UGC mode; metadata delete delegates per ID."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}}

    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert ids == ["id-1", "id-2"]
    query = vector._client.search.call_args.kwargs["body"]
    must_conditions = query["query"]["bool"]["must"]
    # UGC mode adds a routing_field term filter to the lookup query.
    assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions)

    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"])

    # No matching IDs -> nothing is deleted.
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_ids.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-2")
    vector.delete_by_ids.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_by_ids_paths(lindorm_module):
    """delete_by_ids: empty no-op, missing index, existence filtering, 404 tolerance."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )

    # Empty input short-circuits before any client call.
    vector.delete_by_ids([])
    vector._client.indices.exists.assert_not_called()

    # Missing index: returns quietly without bulk-deleting.
    vector._client.indices.exists.return_value = False
    vector.delete_by_ids(["id-1"])

    # Only documents that actually exist are bulk-deleted, carrying the routing value.
    vector._client.indices.exists.return_value = True
    vector._client.exists.side_effect = [True, False]
    lindorm_module.helpers.bulk.reset_mock()
    vector.delete_by_ids(["id-1", "id-2"])
    lindorm_module.helpers.bulk.assert_called_once()
    actions = lindorm_module.helpers.bulk.call_args.args[1]
    assert len(actions) == 1
    assert actions[0]["routing"] == "route"

    # A BulkIndexError containing 404/500 delete statuses must not propagate.
    lindorm_module.helpers.bulk.reset_mock()
    lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError(
        errors=[
            {"delete": {"status": 404, "_id": "id-404"}},
            {"delete": {"status": 500, "_id": "id-500"}},
        ]
    )
    vector._client.exists.side_effect = [True]
    vector.delete_by_ids(["id-1"])
|
||||
|
||||
|
||||
def test_delete_and_text_exists(lindorm_module):
    """delete(): delete-by-query in UGC mode, index drop otherwise; text_exists fallback."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    # UGC mode deletes only this routing's documents, then refreshes.
    vector.delete()
    vector._client.delete_by_query.assert_called_once()
    vector._client.indices.refresh.assert_called_once_with(index="collection")

    # Non-UGC mode drops the whole index when it exists ...
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    vector._client.indices.exists.return_value = True
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60})

    # ... and does nothing when it does not.
    vector._client.indices.delete.reset_mock()
    vector._client.indices.exists.return_value = False
    vector.delete()
    vector._client.indices.delete.assert_not_called()

    # text_exists: successful get -> True; a client error is reported as False.
    assert vector.text_exists("id-1") is True
    vector._client.get.side_effect = RuntimeError("missing")
    assert vector.text_exists("id-1") is False
|
||||
|
||||
|
||||
def test_search_by_vector_validation_and_success(lindorm_module):
    """Vector search: input validation, score threshold, ID filter, routing, errors."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )

    # Query vector must be a list of floats.
    with pytest.raises(ValueError, match="should be a list"):
        vector.search_by_vector("bad")

    with pytest.raises(ValueError, match="should be floats"):
        vector.search_by_vector([0.1, "bad"])

    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.9,
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-a",
                        lindorm_module.Field.VECTOR: [0.1],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"},
                    },
                },
                {
                    "_score": 0.2,
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-b",
                        lindorm_module.Field.VECTOR: [0.2],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"},
                    },
                },
            ]
        }
    }
    # The hit below the 0.5 threshold is dropped.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)

    call_kwargs = vector._client.search.call_args.kwargs
    query = call_kwargs["body"]
    assert "ext" in query
    # document_ids_filter becomes a knn bool filter; the routing param is forwarded.
    assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"]
    assert call_kwargs["params"]["routing"] == "route"

    # Backend errors propagate to the caller.
    vector._client.search.side_effect = RuntimeError("search failed")
    with pytest.raises(RuntimeError, match="search failed"):
        vector.search_by_vector([0.1])
|
||||
|
||||
|
||||
def test_search_by_full_text_success_and_error(lindorm_module):
    """Full-text search returns parsed documents, applies ID filters, re-raises errors."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-a",
                        lindorm_module.Field.VECTOR: [0.1],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "1"},
                    }
                }
            ]
        }
    }

    docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "doc-a"

    # document_ids_filter is translated into a bool filter clause.
    query = vector._client.search.call_args.kwargs["body"]
    assert query["query"]["bool"]["filter"]

    # Backend errors propagate to the caller.
    vector._client.search.side_effect = RuntimeError("full text failed")
    with pytest.raises(RuntimeError, match="full text failed"):
        vector.search_by_full_text("hello")
|
||||
|
||||
|
||||
def test_create_collection_paths(lindorm_module, monkeypatch):
    """create_collection: empty input, redis cache hit, custom index params, existing index."""
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)

    with pytest.raises(ValueError, match="cannot be empty"):
        vector.create_collection([])

    # Redis lock behaves as a no-op context manager.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock())

    # Cache hit: index creation is skipped entirely.
    monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()

    # Cache miss + missing index: created with the supplied index params.
    monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"})
    vector._client.indices.create.assert_called_once()
    body = vector._client.indices.create.call_args.kwargs["body"]
    assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf"
    assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine"

    # Index already exists: nothing is created.
    vector._client.indices.create.reset_mock()
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()
|
||||
|
||||
|
||||
def test_lindorm_factory_branches(lindorm_module, monkeypatch):
    """Factory collection naming: plain vs UGC mode, existing vs new index structs."""
    factory = lindorm_module.LindormVectorStoreFactory()

    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0)
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2")
    monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")

    dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={})
    embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3])

    # An unset UGC flag is rejected up front.
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None)
    with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"):
        factory.init_vector(dataset, attributes=[], embeddings=embeddings)

    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False)

    # Existing non-UGC struct: the lowercased class_prefix becomes the collection name.
    dataset_existing_plain = SimpleNamespace(
        id="dataset-1",
        index_struct="{}",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False},
    )
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings)
    assert result == "vector"
    assert store_cls.call_args.args[0] == "existing"

    # Existing UGC struct: shared ugc index name from stored dimension/index/distance,
    # with the class_prefix used as the routing value.
    dataset_existing_ugc = SimpleNamespace(
        id="dataset-1",
        index_struct="{}",
        index_struct_dict={
            "vector_store": {"class_prefix": "ROUTING"},
            "using_ugc": True,
            "dimension": 1536,
            "index_type": "hnsw",
            "distance_type": "l2",
        },
    )
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2"
    assert store_cls.call_args.kwargs["routing_value"] == "ROUTING"

    # New dataset in UGC mode: dimension probed via embed_query (3 here);
    # the dataset's index_struct is populated.
    dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={})

    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True)
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_new, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2"
    assert store_cls.call_args.kwargs["routing_value"] == "auto_collection"
    assert dataset_new.index_struct is not None

    # New dataset without UGC: generated collection name, no routing value.
    dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={})
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False)
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "auto_collection"
    assert store_cls.call_args.kwargs["routing_value"] is None
|
||||
@ -0,0 +1,252 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_mo_vector_modules():
|
||||
mo_vector = types.ModuleType("mo_vector")
|
||||
mo_vector.__path__ = []
|
||||
mo_vector_client = types.ModuleType("mo_vector.client")
|
||||
|
||||
class MoVectorClient:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.create_full_text_index = MagicMock()
|
||||
self.insert = MagicMock()
|
||||
self.get = MagicMock(return_value=[])
|
||||
self.delete = MagicMock()
|
||||
self.query_by_metadata = MagicMock(return_value=[])
|
||||
self.query = MagicMock(return_value=[])
|
||||
self.full_text_query = MagicMock(return_value=[])
|
||||
|
||||
mo_vector_client.MoVectorClient = MoVectorClient
|
||||
mo_vector.client = mo_vector_client
|
||||
return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client}
|
||||
|
||||
|
||||
@pytest.fixture
def matrixone_module(monkeypatch):
    """Import the matrixone vector module against the fake mo_vector package."""
    for name, module in _build_fake_mo_vector_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.matrixone.matrixone_vector as module

    # Reload so the module re-binds to the fakes even if it was imported earlier.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def _valid_config(module):
|
||||
return module.MatrixoneConfig(
|
||||
host="localhost",
|
||||
port=6001,
|
||||
user="dump",
|
||||
password="111",
|
||||
database="dify",
|
||||
metric="l2",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config host is required"),
        ("port", 0, "config port is required"),
        ("user", "", "config user is required"),
        ("password", "", "config password is required"),
        ("database", "", "config database is required"),
    ],
)
def test_matrixone_config_validation(matrixone_module, field, value, message):
    """Each required config field, when empty/zero, raises a ValidationError with its message."""
    values = _valid_config(matrixone_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        matrixone_module.MatrixoneConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch):
    """Redis cache miss: the full-text index is created and the redis marker is set."""
    # Redis lock behaves as a no-op context manager; get() misses.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())

    vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)

    # The collection name is lowercased for the backing table.
    assert client.kwargs["table_name"] == "collection_1"
    client.create_full_text_index.assert_called_once()
    matrixone_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch):
    """Redis cache hit: index creation and marker write are both skipped."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    # get() hits, signalling the index was created previously.
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())

    vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)

    client.create_full_text_index.assert_not_called()
    matrixone_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module):
    """Decorated methods lazily build the client via _get_client when it is None."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = None
    fake_client = MagicMock()
    fake_client.get.return_value = [{"id": "seg-1"}]
    vector._get_client = MagicMock(return_value=fake_client)

    exists = vector.text_exists("seg-1")

    assert exists is True
    # Lazy-init path passes (dimension=None, create_table=False).
    vector._get_client.assert_called_once_with(None, False)
|
||||
|
||||
|
||||
def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module):
    """Full-text search decodes JSON metadata, scores as 1-distance, and filters by threshold."""
    store = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    store.client = MagicMock()
    # One hit with string metadata (must be JSON-decoded) above threshold,
    # one with dict metadata whose score (1 - 0.7 = 0.3) falls below it.
    store.client.full_text_query.return_value = [
        SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.7),
    ]

    results = store.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])

    assert len(results) == 1
    assert results[0].page_content == "doc-a"
    assert results[0].metadata["doc_id"] == "1"
    assert results[0].metadata["score"] == pytest.approx(0.9)
    assert store.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}}
||||
def test_get_type_and_create_delegate_to_add_texts(matrixone_module):
    """create() builds the table for the embedding dimension then delegates to add_texts."""
    store = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    table_client = MagicMock()
    store._get_client = MagicMock(return_value=table_client)
    store.add_texts = MagicMock(return_value=["seg-1"])
    documents = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]

    created = store.create(documents, [[0.1, 0.2]])

    assert store.get_type() == "matrixone"
    assert created == ["seg-1"]
    # dimension inferred from the embedding length; create_table=True.
    store._get_client.assert_called_once_with(2, True)
    store.add_texts.assert_called_once_with(documents, [[0.1, 0.2]])
||||
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch):
    """An index-creation failure is swallowed and the cache flag is NOT set."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())

    # Client whose full-text index creation always blows up.
    failing_client = MagicMock()
    failing_client.create_full_text_index.side_effect = RuntimeError("boom")
    monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client))

    store = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    client = store._get_client(dimension=3, create_table=True)

    # Client is still returned, but the success flag was never cached.
    assert client is failing_client
    matrixone_module.redis_client.set.assert_not_called()
||||
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch):
    """add_texts uses metadata doc_id when present, generates UUIDs otherwise, and inserts once."""
    store = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    store.client = MagicMock()
    monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid")
    documents = [
        Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}),
        Document(page_content="b", metadata={"document_id": "d-2"}),
        SimpleNamespace(page_content="c", metadata=None),
    ]

    returned_ids = store.add_texts(documents, [[0.1], [0.2], [0.3]])

    # For current prod code, only docs with metadata get ids, so only two ids
    assert returned_ids == ["doc-a", "generated-uuid"]
    store.client.insert.assert_called_once()

    insert_kwargs = store.client.insert.call_args.kwargs
    # All lists passed to insert should be the same length
    assert (
        len(insert_kwargs["texts"])
        == len(insert_kwargs["embeddings"])
        == len(insert_kwargs["metadatas"])
        == len(documents)
    )
    # ids may be shorter than docs for current prod code, but should match number of docs with metadata
    assert insert_kwargs["ids"] == ["doc-a", "generated-uuid"]
||||
def test_delete_and_metadata_methods(matrixone_module):
    """Deletion helpers short-circuit on empty input and otherwise hit client.delete."""
    store = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    store.client = MagicMock()
    store.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")]

    # Empty id list must not touch the client at all.
    store.delete_by_ids([])
    store.client.delete.assert_not_called()

    store.delete_by_ids(["seg-1"])
    store.delete_by_metadata_field("document_id", "doc-1")
    found_ids = store.get_ids_by_metadata_field("document_id", "doc-1")
    store.delete()

    assert found_ids == ["seg-1", "seg-2"]
    # One delete per: delete_by_ids, delete_by_metadata_field, delete().
    assert store.client.delete.call_count == 3
||||
def test_search_by_vector_builds_documents(matrixone_module):
    """Vector search converts each client row into a Document and forwards the id filter."""
    store = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    store.client = MagicMock()
    store.client.query.return_value = [
        SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}),
    ]

    results = store.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"])

    assert len(results) == 2
    assert results[0].page_content == "doc-a"
    assert results[1].metadata["doc_id"] == "2"
    assert store.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}}
||||
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch):
    """Factory reuses the collection name from index_struct_dict, else generates one."""
    factory = matrixone_module.MatrixoneVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Minimal config so MatrixoneConfig validation passes.
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001)
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2")

    with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # Factory persists the generated index struct onto the dataset.
    assert dataset_without_index.index_struct is not None
@ -1,18 +1,414 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_default_value():
|
||||
def _build_fake_pymilvus_modules():
|
||||
pymilvus = types.ModuleType("pymilvus")
|
||||
pymilvus.__path__ = []
|
||||
pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client")
|
||||
pymilvus_orm = types.ModuleType("pymilvus.orm")
|
||||
pymilvus_orm.__path__ = []
|
||||
pymilvus_orm_types = types.ModuleType("pymilvus.orm.types")
|
||||
|
||||
class MilvusError(Exception):
|
||||
pass
|
||||
|
||||
class MilvusClient:
|
||||
def __init__(self, **kwargs):
|
||||
self.init_kwargs = kwargs
|
||||
self.has_collection = MagicMock(return_value=False)
|
||||
self.describe_collection = MagicMock(
|
||||
return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]}
|
||||
)
|
||||
self.get_server_version = MagicMock(return_value="2.5.0")
|
||||
self.insert = MagicMock(return_value=[1])
|
||||
self.query = MagicMock(return_value=[])
|
||||
self.delete = MagicMock()
|
||||
self.drop_collection = MagicMock()
|
||||
self.search = MagicMock(return_value=[[]])
|
||||
self.create_collection = MagicMock()
|
||||
|
||||
class IndexParams:
|
||||
def __init__(self):
|
||||
self.indexes = []
|
||||
|
||||
def add_index(self, **kwargs):
|
||||
self.indexes.append(kwargs)
|
||||
|
||||
class DataType:
|
||||
JSON = "JSON"
|
||||
VARCHAR = "VARCHAR"
|
||||
INT64 = "INT64"
|
||||
SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR"
|
||||
FLOAT_VECTOR = "FLOAT_VECTOR"
|
||||
|
||||
class FieldSchema:
|
||||
def __init__(self, name, dtype, **kwargs):
|
||||
self.name = name
|
||||
self.dtype = dtype
|
||||
self.kwargs = kwargs
|
||||
|
||||
class CollectionSchema:
|
||||
def __init__(self, fields):
|
||||
self.fields = fields
|
||||
self.functions = []
|
||||
|
||||
def add_function(self, func):
|
||||
self.functions.append(func)
|
||||
|
||||
class FunctionType:
|
||||
BM25 = "BM25"
|
||||
|
||||
class Function:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
def infer_dtype_bydata(_value):
|
||||
return DataType.FLOAT_VECTOR
|
||||
|
||||
pymilvus.MilvusException = MilvusError
|
||||
pymilvus.MilvusClient = MilvusClient
|
||||
pymilvus.IndexParams = IndexParams
|
||||
pymilvus.CollectionSchema = CollectionSchema
|
||||
pymilvus.DataType = DataType
|
||||
pymilvus.FieldSchema = FieldSchema
|
||||
pymilvus.Function = Function
|
||||
pymilvus.FunctionType = FunctionType
|
||||
pymilvus_milvus_client.IndexParams = IndexParams
|
||||
pymilvus_orm.types = pymilvus_orm_types
|
||||
pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata
|
||||
|
||||
# Attach submodules for dotted imports
|
||||
pymilvus.milvus_client = pymilvus_milvus_client
|
||||
pymilvus.orm = pymilvus_orm
|
||||
|
||||
return {
|
||||
"pymilvus": pymilvus,
|
||||
"pymilvus.milvus_client": pymilvus_milvus_client,
|
||||
"pymilvus.orm": pymilvus_orm,
|
||||
"pymilvus.orm.types": pymilvus_orm_types,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def milvus_module(monkeypatch):
    """Reload milvus_vector against the fake pymilvus modules."""
    # Install every fake submodule before the production import resolves them.
    for dotted_name, fake in _build_fake_pymilvus_modules().items():
        monkeypatch.setitem(sys.modules, dotted_name, fake)

    import core.rag.datasource.vdb.milvus.milvus_vector as module

    # Reload so the module binds to the fakes even if imported earlier.
    return importlib.reload(module)
|
||||
def _config(module, **overrides):
    """Build a validated MilvusConfig with sensible defaults, applying overrides."""
    payload = {
        "uri": "http://localhost:19530",
        "user": "root",
        "password": "Milvus",
        "database": "default",
        "enable_hybrid_search": False,
        "analyzer_params": None,
        **overrides,
    }
    return module.MilvusConfig.model_validate(payload)
||||
|
||||
def test_config_validation_and_defaults(milvus_module):
    """Missing required keys raise ValidationError; defaults and token auth work.

    Fix: diff residue had left BOTH the stale direct ``MilvusConfig`` calls and
    the new ``milvus_module.MilvusConfig`` calls in place, duplicating the call
    under ``pytest.raises`` and leaving a dead first assignment. Only the
    reloaded-module variant is kept so the fake pymilvus backing is exercised.
    """
    valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"}

    # Each required key, when removed, must produce the matching error message.
    for key in valid_config:
        config = valid_config.copy()
        del config[key]
        with pytest.raises(ValidationError) as e:
            milvus_module.MilvusConfig.model_validate(config)
        assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"

    config = milvus_module.MilvusConfig.model_validate(valid_config)
    assert config.database == "default"

    # Token-based auth is an alternative to user/password.
    token_config = milvus_module.MilvusConfig.model_validate(
        {"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"}
    )
    assert token_config.token == "token-value"
||||
|
||||
def test_config_to_milvus_params(milvus_module):
    """to_milvus_params maps config fields to client connection kwargs."""
    cfg = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}')

    params = cfg.to_milvus_params()

    assert params["uri"] == "http://localhost:19530"
    assert params["db_name"] == "default"
    assert params["analyzer_params"] == '{"tokenizer":"standard"}'
||||
|
||||
def test_init_client_supports_token_and_user_password(milvus_module):
    """_init_client passes either token or user/password credentials through."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)

    token_client = vector._init_client(
        milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"})
    )
    assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"}

    user_client = vector._init_client(_config(milvus_module))
    kwargs = user_client.init_kwargs
    assert kwargs["uri"] == "http://localhost:19530"
    assert kwargs["user"] == "root"
    assert kwargs["password"] == "Milvus"
||||
|
||||
def test_init_loads_fields_when_collection_exists(milvus_module):
    """On init, existing collections have their fields loaded minus the id field."""
    fake_client = milvus_module.MilvusClient(uri="http://localhost:19530")
    fake_client.has_collection.return_value = True
    fake_client.describe_collection.return_value = {
        "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}]
    }

    with (
        patch.object(milvus_module.MilvusVector, "_init_client", return_value=fake_client),
        patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False),
    ):
        vector = milvus_module.MilvusVector("collection_1", _config(milvus_module))

    # Primary-key field is excluded; payload fields remain.
    assert "id" not in vector._fields
    assert "content" in vector._fields
||||
|
||||
def test_load_collection_fields_from_argument_and_remote(milvus_module):
    """_load_collection_fields accepts an explicit list or fetches from the server."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._client = MagicMock()
    vector._collection_name = "collection_1"
    vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]}

    # Explicit field list: "id" is stripped, the rest kept as-is.
    vector._load_collection_fields(["id", "metadata"])
    assert vector._fields == ["metadata"]

    # No argument: fields come from describe_collection, again minus "id".
    vector._load_collection_fields()
    assert vector._fields == ["content"]
||||
|
||||
def test_check_hybrid_search_support_branches(milvus_module):
    """Hybrid search support depends on the config flag and the server version."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._client = MagicMock()

    # Disabled in config: always False regardless of server.
    vector._client_config = SimpleNamespace(enable_hybrid_search=False)
    assert vector._check_hybrid_search_support() is False

    vector._client_config = SimpleNamespace(enable_hybrid_search=True)

    # Zilliz Cloud reports support regardless of numeric version.
    vector._client.get_server_version.return_value = "Zilliz Cloud 2.4"
    assert vector._check_hybrid_search_support() is True

    # Self-hosted needs >= 2.5.
    vector._client.get_server_version.return_value = "2.5.1"
    assert vector._check_hybrid_search_support() is True
    vector._client.get_server_version.return_value = "2.4.9"
    assert vector._check_hybrid_search_support() is False

    # Version probe failure degrades gracefully to False.
    vector._client.get_server_version.side_effect = RuntimeError("boom")
    assert vector._check_hybrid_search_support() is False
||||
|
||||
def test_get_type_and_create_delegate(milvus_module):
    """create() forwards embeddings/metadatas to create_collection then add_texts."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [SimpleNamespace(page_content="hello", metadata=None)]

    vector.create(docs, [[0.1, 0.2]])

    assert vector.get_type() == "milvus"
    vector.create_collection.assert_called_once()
    embeddings_arg, metadatas_arg = vector.create_collection.call_args.args[:2]
    assert embeddings_arg == [[0.1, 0.2]]
    # None metadata is normalized to an empty dict.
    assert metadatas_arg == [{}]
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
||||
|
||||
def test_add_texts_batches_and_raises_milvus_exception(milvus_module):
    """add_texts inserts in 1000-row batches and propagates MilvusException."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.insert.side_effect = [["id-1"], ["id-2"]]

    # 1001 docs forces exactly two insert batches.
    docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)]
    embeddings = [[0.1, 0.2] for _ in range(1001)]

    inserted_ids = vector.add_texts(docs, embeddings)
    assert inserted_ids == ["id-1", "id-2"]
    assert vector._client.insert.call_count == 2

    # Client errors surface unchanged to the caller.
    vector._client.insert.side_effect = milvus_module.MilvusException("insert failed")
    with pytest.raises(milvus_module.MilvusException):
        vector.add_texts([Document(page_content="x", metadata={})], [[0.1]])
||||
|
||||
def test_get_ids_and_delete_methods(milvus_module):
    """Id lookup, metadata/id deletion, and full drop all hit the client correctly."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()

    vector._client.query.return_value = [{"id": 1}, {"id": 2}]
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2]
    # Empty query result yields None rather than an empty list.
    vector._client.query.return_value = []
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None

    vector._client.has_collection.return_value = True
    vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102])

    # delete_by_ids resolves doc ids to primary keys via a metadata query.
    vector._client.delete.reset_mock()
    vector._client.query.return_value = [{"id": 11}, {"id": 12}]
    vector.delete_by_ids(["doc-a", "doc-b"])
    vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12])

    vector._client.has_collection.return_value = True
    vector.delete()
    vector._client.drop_collection.assert_called_once_with("collection_1", None)
||||
|
||||
def test_text_exists_and_field_exists(milvus_module):
    """text_exists checks collection presence then query results; field_exists checks _fields."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._fields = ["content", "metadata"]
    vector._client = MagicMock()

    # No collection at all -> trivially absent.
    vector._client.has_collection.return_value = False
    assert vector.text_exists("doc-1") is False

    vector._client.has_collection.return_value = True
    vector._client.query.return_value = [{"id": 1}]
    assert vector.text_exists("doc-1") is True
    vector._client.query.return_value = []
    assert vector.text_exists("doc-1") is False

    assert vector.field_exists("content") is True
    assert vector.field_exists("unknown") is False
||||
|
||||
def test_process_search_results_and_search_methods(milvus_module):
    """Result parsing applies the score threshold; both search paths build filters."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._fields = ["content", "metadata", "sparse_vector"]

    # Raw hits: 0.9 passes the 0.5 threshold, 0.2 is dropped.
    raw_hits = [
        [
            {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9},
            {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2},
        ]
    ]
    kept = vector._process_search_results(
        raw_hits,
        [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY],
        score_threshold=0.5,
    )
    assert len(kept) == 1
    assert kept[0].metadata["score"] == 0.9

    # Vector search: filter expression targets metadata["document_id"].
    vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]]
    vector._process_search_results = MagicMock(return_value=["doc"])
    found = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1)
    assert found == ["doc"]
    assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]'

    # Full-text search requires both the hybrid flag and a sparse-vector field.
    vector._hybrid_search_enabled = False
    assert vector.search_by_full_text("query") == []

    vector._hybrid_search_enabled = True
    vector._fields = []
    assert vector.search_by_full_text("query") == []

    vector._fields = [milvus_module.Field.SPARSE_VECTOR]
    vector._process_search_results = MagicMock(return_value=["full-text-doc"])
    full_text_hits = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2)
    assert full_text_hits == ["full-text-doc"]
    assert "document_id" in vector._client.search.call_args.kwargs["filter"]
||||
|
||||
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch):
    """Redis cache hit skips creation; an existing collection still sets the cache flag."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock())

    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._consistency_level = "Session"
    vector._client_config = _config(milvus_module)
    vector._hybrid_search_enabled = False
    vector._client = MagicMock()

    # Cache hit: creation short-circuits entirely.
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"})
    vector._client.create_collection.assert_not_called()

    # Cache miss but collection already present: flag gets recorded.
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.has_collection.return_value = True
    vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"})
    milvus_module.redis_client.set.assert_called()
||||
|
||||
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch):
    """With hybrid search on, the schema gains a sparse field, a BM25 function, and two indexes."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock())

    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._consistency_level = "Session"
    vector._client = MagicMock()
    vector._client.has_collection.return_value = False
    vector._load_collection_fields = MagicMock()
    vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}')
    vector._hybrid_search_enabled = True

    vector.create_collection(
        embeddings=[[0.1, 0.2]],
        metadatas=[{"doc_id": "1"}],
        index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}},
    )

    create_kwargs = vector._client.create_collection.call_args.kwargs
    schema = create_kwargs["schema"]
    built_indexes = create_kwargs["index_params"]
    declared_fields = [field.name for field in schema.fields]

    assert milvus_module.Field.SPARSE_VECTOR in declared_fields
    assert len(schema.functions) == 1  # the BM25 function
    assert len(built_indexes.indexes) == 2  # dense + sparse index
    assert create_kwargs["consistency_level"] == "Session"
||||
|
||||
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch):
    """Factory reuses the stored collection name or generates one, and persists index_struct."""
    factory = milvus_module.MilvusVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Full Milvus config so MilvusConfig validation succeeds.
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True)
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}')

    with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,230 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_clickhouse_connect_module():
|
||||
clickhouse_connect = types.ModuleType("clickhouse_connect")
|
||||
|
||||
class QueryResult:
|
||||
def __init__(self, rows=None, named_rows=None):
|
||||
self.row_count = len(rows or [])
|
||||
self.result_rows = rows or []
|
||||
self._named_rows = named_rows or []
|
||||
|
||||
def named_results(self):
|
||||
return self._named_rows
|
||||
|
||||
class Client:
|
||||
def __init__(self):
|
||||
self.command = MagicMock()
|
||||
self.query = MagicMock(return_value=QueryResult())
|
||||
|
||||
client = Client()
|
||||
|
||||
def get_client(**_kwargs):
|
||||
return client
|
||||
|
||||
clickhouse_connect.get_client = get_client
|
||||
clickhouse_connect.QueryResult = QueryResult
|
||||
clickhouse_connect._fake_client = client
|
||||
return clickhouse_connect
|
||||
|
||||
|
||||
@pytest.fixture
def myscale_module(monkeypatch):
    """Reload myscale_vector against the fake clickhouse_connect module."""
    monkeypatch.setitem(sys.modules, "clickhouse_connect", _build_fake_clickhouse_connect_module())

    import core.rag.datasource.vdb.myscale.myscale_vector as module

    # Reload so module-level bindings pick up the fake backend.
    return importlib.reload(module)
|
||||
|
||||
def _config(module):
    """Return a baseline MyScaleConfig pointing at a local instance, no FTS params."""
    return module.MyScaleConfig(
        host="localhost",
        port=8123,
        user="default",
        password="",
        database="dify",
        fts_params="",
    )
||||
|
||||
def test_escape_str_replaces_backslash_and_quote(myscale_module):
    """escape_str neutralizes backslashes and single quotes with spaces."""
    assert myscale_module.MyScaleVector.escape_str(r"text\with'special") == "text with special"
|
||||
|
||||
def test_search_raises_for_invalid_top_k(myscale_module):
    """_search rejects a non-positive top_k before touching the database."""
    store = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        store._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0)
|
||||
|
||||
def test_search_builds_where_clause_for_cosine_threshold(myscale_module):
    """A cosine score threshold becomes a `dist < 1 - threshold` WHERE clause."""
    store = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    # Rebuild a fake QueryResult of the same class the fake client returns.
    result_cls = myscale_module.get_client().query.return_value.__class__
    store._client.query.return_value = result_cls(
        named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}]
    )

    hits = store._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2)

    assert len(hits) == 1
    issued_sql = store._client.query.call_args.args[0]
    assert "WHERE dist < 0.8" in issued_sql
|
||||
|
||||
def test_delete_by_ids_short_circuits_on_empty_list(myscale_module):
    """An empty id list issues no SQL at all."""
    store = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    store._client.command.reset_mock()  # drop init-time SET commands

    store.delete_by_ids([])

    store._client.command.assert_not_called()
|
||||
|
||||
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch):
    """MyScale factory lower-cases collection names from either source."""
    factory = myscale_module.MyScaleVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123)
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "")

    with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # Names are forced to lower case for ClickHouse table naming.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    assert dataset_without_index.index_struct is not None
|
||||
|
||||
def test_init_and_get_type_set_expected_defaults(myscale_module):
    """Construction enables the experimental object type and default sort order."""
    store = myscale_module.MyScaleVector("collection_1", _config(myscale_module))

    assert store.get_type() == "myscale"
    assert store._vec_order == myscale_module.SortOrder.ASC
    store._client.command.assert_called_with("SET allow_experimental_object_type=1")
|
||||
|
||||
def test_create_calls_create_collection_and_add_texts(myscale_module):
    """create() builds the table sized to the embedding dimension then adds texts."""
    store = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    store._create_collection = MagicMock()
    store.add_texts = MagicMock(return_value=["seg-1"])
    documents = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]

    created = store.create(documents, [[0.1, 0.2]])

    assert created == ["seg-1"]
    store._create_collection.assert_called_once_with(2)  # dim from embedding length
    store.add_texts.assert_called_once()
||||
|
||||
def test_create_collection_builds_expected_sql(myscale_module):
    """_create_collection emits CREATE TABLE with a length constraint and FTS index."""
    fts_config = myscale_module.MyScaleConfig(
        host="localhost",
        port=8123,
        user="default",
        password="",
        database="dify",
        fts_params="tokenizer=unicode",
    )
    store = myscale_module.MyScaleVector("collection_1", fts_config)
    store._client.command.reset_mock()  # discard init-time commands

    store._create_collection(3)

    # One CREATE DATABASE + one CREATE TABLE.
    assert store._client.command.call_count == 2
    ddl = store._client.command.call_args_list[1].args[0]
    assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in ddl
    assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in ddl
    assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in ddl
|
||||
|
||||
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
    """add_texts should use doc_id when present, generate uuids otherwise, and sanitize text.

    The third document has ``metadata=None`` and its id is absent from the
    returned list, so documents without metadata are expected to be skipped.
    """
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}),
        Document(page_content="text-2", metadata={"document_id": "d-2"}),
        SimpleNamespace(page_content="text-3", metadata=None),
    ]

    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])

    assert ids == ["doc-a", "generated-uuid"]
    sql = vector._client.command.call_args.args[0]
    assert "INSERT INTO dify.collection_1" in sql
    # Single quote and backslash in the page content are replaced with spaces
    # before being embedded in the INSERT statement.
    assert "te xt 1" in sql
|
||||
|
||||
|
||||
def test_text_exists_and_metadata_operations(myscale_module):
    """Existence check, metadata lookup, and both delete variants issue client calls."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    # One stubbed query result serves both text_exists (row_count) and
    # get_ids_by_metadata_field (result_rows).
    vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)])

    assert vector.text_exists("id-1") is True
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]

    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    # >= 2 because __init__ also issues a command; each delete adds one more.
    assert vector._client.command.call_count >= 2
|
||||
|
||||
|
||||
def test_search_delegation_methods(myscale_module):
    """Both public search entry points delegate to the private _search helper."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._search = MagicMock(return_value=["result"])

    result_vector = vector.search_by_vector([0.1, 0.2], top_k=2)
    result_text = vector.search_by_full_text("hello", top_k=2)

    assert result_vector == ["result"]
    assert result_text == ["result"]
    assert vector._search.call_count == 2
|
||||
|
||||
|
||||
def test_search_with_document_filter_and_exception(myscale_module):
    """_search builds a document-id IN filter and degrades to [] on query errors."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.query.return_value = SimpleNamespace(
        named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}]
    )

    docs = vector._search(
        "distance(vector, [0.1])",
        myscale_module.SortOrder.ASC,
        top_k=2,
        document_ids_filter=["doc-1", "doc-2"],
    )
    assert len(docs) == 1
    sql = vector._client.query.call_args.args[0]
    # The filter is rendered as a quoted IN-list over the metadata map column.
    assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql

    # Query failures are swallowed and surfaced as an empty result list.
    vector._client.query.side_effect = RuntimeError("boom")
    assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == []
|
||||
|
||||
|
||||
def test_delete_drops_table(myscale_module):
    """delete() issues a single DROP TABLE IF EXISTS for the collection."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    # Clear the command issued during __init__ so assert_called_once_with holds.
    vector._client.command.reset_mock()

    vector.delete()

    vector._client.command.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1")
|
||||
@ -0,0 +1,553 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_pyobvector_module():
|
||||
pyobvector = types.ModuleType("pyobvector")
|
||||
|
||||
class VECTOR:
|
||||
def __init__(self, dim):
|
||||
self.dim = dim
|
||||
|
||||
def l2_distance(*_args, **_kwargs):
|
||||
return "l2"
|
||||
|
||||
def cosine_distance(*_args, **_kwargs):
|
||||
return "cosine"
|
||||
|
||||
def inner_product(*_args, **_kwargs):
|
||||
return "inner_product"
|
||||
|
||||
class ObVecClient:
|
||||
def __init__(self, **_kwargs):
|
||||
self.metadata_obj = SimpleNamespace(tables={})
|
||||
self.engine = MagicMock()
|
||||
self.check_table_exists = MagicMock(return_value=False)
|
||||
self.perform_raw_text_sql = MagicMock()
|
||||
self.prepare_index_params = MagicMock()
|
||||
self.create_table_with_index_params = MagicMock()
|
||||
self.refresh_metadata = MagicMock()
|
||||
self.insert = MagicMock()
|
||||
self.refresh_index = MagicMock()
|
||||
self.get = MagicMock()
|
||||
self.delete = MagicMock()
|
||||
self.set_ob_hnsw_ef_search = MagicMock()
|
||||
self.ann_search = MagicMock(return_value=[])
|
||||
self.drop_table_if_exist = MagicMock()
|
||||
|
||||
pyobvector.VECTOR = VECTOR
|
||||
pyobvector.ObVecClient = ObVecClient
|
||||
pyobvector.l2_distance = l2_distance
|
||||
pyobvector.cosine_distance = cosine_distance
|
||||
pyobvector.inner_product = inner_product
|
||||
return pyobvector
|
||||
|
||||
|
||||
@pytest.fixture
def oceanbase_module(monkeypatch):
    """Import the OceanBase vector module against the fake pyobvector driver.

    The module is reloaded so its top-level ``from pyobvector import ...``
    resolves against the fake injected into ``sys.modules`` for this test.
    """
    monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())

    import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module

    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module):
    """Return a minimal valid OceanBaseVectorConfig for tests."""
    return module.OceanBaseVectorConfig(
        host="127.0.0.1",
        port=2881,
        user="root",
        password="secret",
        database="test",
        enable_hybrid_search=True,
        batch_size=10,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OCEANBASE_VECTOR_HOST is required"),
        ("port", 0, "config OCEANBASE_VECTOR_PORT is required"),
        ("user", "", "config OCEANBASE_VECTOR_USER is required"),
        ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"),
    ],
)
def test_oceanbase_config_validation(oceanbase_module, field, value, message):
    """Each required config field raises a field-specific validation error when empty."""
    values = _config(oceanbase_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        oceanbase_module.OceanBaseVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_rejects_invalid_collection_name(oceanbase_module):
    """Collection names containing disallowed characters are rejected at construction."""
    config = _config(oceanbase_module)
    with pytest.raises(ValueError, match="Invalid collection name"):
        oceanbase_module.OceanBaseVector("invalid-name", config)
|
||||
|
||||
|
||||
def test_distance_to_score_for_supported_metrics(oceanbase_module):
    """_distance_to_score conversions per metric: l2 -> 1/(1+d), cosine -> 1-d, ip -> -d."""
    # __new__ bypasses __init__ so no client connection is attempted.
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="l2")
    assert vector._distance_to_score(3.0) == pytest.approx(0.25)

    vector._config = SimpleNamespace(metric_type="cosine")
    assert vector._distance_to_score(0.2) == pytest.approx(0.8)

    vector._config = SimpleNamespace(metric_type="inner_product")
    assert vector._distance_to_score(-0.2) == pytest.approx(0.2)
|
||||
|
||||
|
||||
def test_get_distance_func_raises_for_unknown_metric(oceanbase_module):
    """An unrecognized metric_type makes _get_distance_func raise ValueError."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="manhattan")

    with pytest.raises(ValueError, match="Unsupported metric_type"):
        vector._get_distance_func()
|
||||
|
||||
|
||||
def test_process_search_results_handles_json_and_score_threshold(oceanbase_module):
    """Rows with JSON-string, non-JSON, or dict metadata are handled; low scores filtered.

    The third row scores 0.3 < threshold 0.5 and is dropped; the non-JSON
    metadata row is still returned with the score attached under score_key.
    """
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    rows = [
        ("doc-1", '{"doc_id":"1"}', 0.9),
        ("doc-2", "not-json", 0.8),
        ("doc-3", {"doc_id": "3"}, 0.3),
    ]

    docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank")

    assert len(docs) == 2
    assert docs[0].metadata["doc_id"] == "1"
    assert docs[0].metadata["rank"] == 0.9
    assert docs[1].metadata["rank"] == 0.8
|
||||
|
||||
|
||||
def test_search_by_vector_validates_document_id_format(oceanbase_module):
    """Document-id filters with unexpected characters (here a space) are rejected early."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()

    with pytest.raises(ValueError, match="Invalid document ID format"):
        vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"])
|
||||
|
||||
|
||||
def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module):
    """Full-text search short-circuits to [] when hybrid search is disabled."""
    instance = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    instance._hybrid_search_enabled = False
    instance._collection_name = "collection_1"

    result = instance.search_by_full_text("query")

    assert result == []
|
||||
|
||||
|
||||
def test_check_hybrid_search_support_uses_version_comment(oceanbase_module):
    """Hybrid search support is decided by parsing the server version_comment.

    Version 4.3.5.1 is accepted; 4.3.4.0 is below the supported threshold.
    """
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client = MagicMock()
    cursor = MagicMock()
    cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",)
    vector._client.perform_raw_text_sql.return_value = cursor

    assert vector._check_hybrid_search_support() is True

    cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",)
    assert vector._check_hybrid_search_support() is False
|
||||
|
||||
|
||||
def test_init_get_type_and_field_loading(oceanbase_module):
    """Init against an existing table loads its column names for field_exists()."""
    config = _config(oceanbase_module)
    config.enable_hybrid_search = False

    table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")])
    fake_client = oceanbase_module.ObVecClient()
    fake_client.check_table_exists.return_value = True
    fake_client.metadata_obj.tables = {"collection_1": table}

    with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client):
        vector = oceanbase_module.OceanBaseVector("collection_1", config)

    assert vector.get_type() == "oceanbase"
    assert vector.field_exists("text") is True
|
||||
|
||||
|
||||
def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module):
    """Field loading degrades to an empty list when the table is absent or inspection fails."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._fields = []
    vector._client = MagicMock()
    vector._client.metadata_obj.tables = {}

    # Missing table: stays empty, no exception.
    vector._load_collection_fields()
    assert vector._fields == []

    # Broken table metadata: swallowed, stays empty.
    vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))}
    vector._load_collection_fields()
    assert vector._fields == []
|
||||
|
||||
|
||||
def test_create_delegates_to_collection_and_insert(oceanbase_module):
    """create() records the embedding dimension and delegates to the two helpers."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="text", metadata={"doc_id": "1"})]

    vector.create(docs, [[0.1, 0.2]])

    # Dimension inferred from the first embedding's length.
    assert vector._vec_dim == 2
    vector._create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
    """_create_collection skips work when the redis flag is set or the table already exists."""
    # Context-manager stand-in for the redis distributed lock.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = False
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    # Cache hit: nothing is checked or created.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection()
    vector._client.check_table_exists.assert_not_called()

    # Cache miss but table present: no destructive recreate.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.check_table_exists.return_value = True
    vector._create_collection()
    vector.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
    """Full table creation: memory check, vector index, fulltext index, metadata refresh, cache set."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    # Lightweight stand-ins so no real SQLAlchemy column objects are built.
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 3
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    # First raw-SQL call returns ob_vector_memory_limit_percentage="30" (already
    # configured), the next calls (fulltext/metadata index DDL) succeed silently.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        None,
        None,
    ]
    index_params = MagicMock()
    vector._client.prepare_index_params.return_value = index_params
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    vector._create_collection()

    vector.delete.assert_called_once()
    vector._client.create_table_with_index_params.assert_called_once()
    index_params.add_index.assert_called_once()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])
    oceanbase_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_error_paths(oceanbase_module, monkeypatch):
    """Error handling in _create_collection: missing memory var, set failure, bad parser name."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    # 1) Server variable query returns no rows.
    vector._client.perform_raw_text_sql.return_value = []
    with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"):
        vector._create_collection()

    # 2) Variable is "0" so a SET is attempted and fails (e.g. missing privilege).
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "0"]],
        RuntimeError("no privilege"),
    ]
    with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"):
        vector._create_collection()

    # 3) Unsupported fulltext parser name is rejected.
    vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]]
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid")
    with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"):
        vector._create_collection()
|
||||
|
||||
|
||||
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
    """Fulltext index failure is fatal; metadata index failure is tolerated."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))

    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()

    # Fulltext DDL (second raw-SQL call) raising is wrapped and re-raised.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        RuntimeError("fulltext failed"),
    ]
    with pytest.raises(Exception, match="Failed to add fulltext index"):
        vector._create_collection()

    # With hybrid search off, a metadata-index SQLAlchemyError does not abort creation.
    vector._hybrid_search_enabled = False
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        SQLAlchemyError("metadata index failed"),
    ]
    vector._create_collection()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])
|
||||
|
||||
|
||||
def test_check_hybrid_search_support_false_and_exception(oceanbase_module):
    """Support check is False when disabled by config or when the version query fails."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(enable_hybrid_search=False)
    vector._client = MagicMock()
    assert vector._check_hybrid_search_support() is False

    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom")
    assert vector._check_hybrid_search_support() is False
|
||||
|
||||
|
||||
def test_add_texts_batches_refresh_and_exceptions(oceanbase_module):
    """add_texts splits inserts by batch_size, refreshes the index past the threshold,
    wraps insert failures, and tolerates refresh failures."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2)
    vector._client = MagicMock()
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
        Document(page_content="c", metadata={"doc_id": "id-3"}),
    ]

    # 3 docs / batch_size 2 -> two insert calls; count >= threshold triggers one refresh.
    vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    assert vector._client.insert.call_count == 2
    vector._client.refresh_index.assert_called_once()

    # Insert errors are wrapped with a descriptive message.
    vector._client.insert.reset_mock()
    vector._client.refresh_index.reset_mock()
    vector._client.insert.side_effect = RuntimeError("insert failed")
    with pytest.raises(Exception, match="Failed to insert batch"):
        vector.add_texts([docs[0]], [[0.1]])

    # A failing refresh_index does not propagate out of add_texts.
    vector._client.insert.side_effect = None
    vector._client.insert.return_value = None
    vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed")
    vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1)
    vector._get_uuids.return_value = ["id-1"]
    vector.add_texts([docs[0]], [[0.1]])
|
||||
|
||||
|
||||
def test_text_exists_and_delete_by_ids(oceanbase_module):
    """text_exists and delete_by_ids: happy paths, empty-input no-op, and wrapped errors."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.get.return_value = SimpleNamespace(rowcount=1)
    assert vector.text_exists("id-1") is True

    vector._client.get.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to check text existence"):
        vector.text_exists("id-1")

    # Empty id list is a no-op — no delete issued.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()

    vector._client.delete.side_effect = None
    vector.delete_by_ids(["id-1"])
    vector._client.delete.assert_called_once()

    vector._client.delete.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete documents"):
        vector.delete_by_ids(["id-1"])
|
||||
|
||||
|
||||
def test_get_ids_and_delete_by_metadata_field(oceanbase_module):
    """Metadata-field lookup returns ids, rejects unsafe keys, and drives the delete path."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    execute_result = [("id-1",), ("id-2",)]

    # Connection behaves as a context manager yielding itself.
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value = execute_result
    vector._client.engine.connect.return_value = conn

    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert ids == ["id-1", "id-2"]

    # Keys with characters outside the allowed format are rejected.
    with pytest.raises(Exception, match="Failed to query documents by metadata field"):
        vector.get_ids_by_metadata_field("bad key!", "doc-1")

    # delete_by_metadata_field delegates: deletes when ids found...
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])

    # ...and skips the delete when nothing matched.
    vector.get_ids_by_metadata_field = MagicMock(return_value=[])
    vector.delete_by_ids.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_not_called()
|
||||
|
||||
|
||||
def test_search_by_full_text_paths(oceanbase_module):
    """Full-text search: missing text column -> [], happy path with filter/threshold, bad top_k raises."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hybrid_search_enabled = True
    vector.field_exists = MagicMock(return_value=False)

    # No searchable text field: short-circuit to an empty result.
    assert vector.search_by_full_text("query") == []

    vector.field_exists.return_value = True
    vector._client = MagicMock()
    # Connection + transaction mocks both act as context managers.
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = tx
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)]
    vector._client.engine.connect.return_value = conn

    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.9

    # Non-positive top_k is invalid and surfaces as a wrapped failure.
    with pytest.raises(Exception, match="Full-text search failed"):
        vector.search_by_full_text("query", top_k=0)
|
||||
|
||||
|
||||
def test_search_by_vector_paths(oceanbase_module):
    """Vector search: ef_search is applied, results post-processed, bad args and errors raise."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()
    vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)]
    vector._process_search_results = MagicMock(return_value=["doc"])

    docs = vector.search_by_vector(
        [0.1, 0.2],
        ef_search=10,
        top_k=3,
        score_threshold=0.1,
        document_ids_filter=["good_id"],
    )
    assert docs == ["doc"]
    # ef_search differs from the cached -1, so the session variable is updated.
    vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10)

    # Non-numeric score_threshold is rejected up front.
    with pytest.raises(ValueError, match="Invalid score_threshold parameter"):
        vector.search_by_vector([0.1], score_threshold="x")

    # Driver errors are wrapped with a descriptive message.
    vector._client.ann_search.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Vector search failed"):
        vector.search_by_vector([0.1], score_threshold=0.1)
|
||||
|
||||
|
||||
def test_get_distance_func_and_distance_to_score_errors(oceanbase_module):
    """cosine maps to the pyobvector cosine_distance; unknown metric raises in _distance_to_score."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="cosine")
    assert vector._get_distance_func() is oceanbase_module.cosine_distance

    vector._config = SimpleNamespace(metric_type="unknown")
    with pytest.raises(ValueError, match="Unsupported metric_type"):
        vector._distance_to_score(0.1)
|
||||
|
||||
|
||||
def test_delete_success_and_exception(oceanbase_module):
    """delete() drops the backing table; driver errors are wrapped."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()

    vector.delete()
    vector._client.drop_table_if_exist.assert_called_once_with("collection_1")

    vector._client.drop_table_if_exist.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete collection"):
        vector.delete()
|
||||
|
||||
|
||||
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
    """Factory reuses the collection name from index_struct_dict, else generates one.

    Both expected names are asserted lowercased, so the factory is expected to
    normalize the upper-case inputs; the dataset without an index also gets its
    index_struct populated as a side effect.
    """
    factory = oceanbase_module.OceanBaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Pin every dify_config value the factory reads when building the vector config.
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000)

    with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].args[0] == "existing_collection"
    assert vector_cls.call_args_list[1].args[0] == "auto_collection"
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,400 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_psycopg2_modules():
|
||||
psycopg2 = types.ModuleType("psycopg2")
|
||||
psycopg2.__path__ = []
|
||||
psycopg2_extras = types.ModuleType("psycopg2.extras")
|
||||
psycopg2_pool = types.ModuleType("psycopg2.pool")
|
||||
|
||||
class SimpleConnectionPool:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.getconn = MagicMock()
|
||||
self.putconn = MagicMock()
|
||||
|
||||
psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool
|
||||
psycopg2_extras.execute_values = MagicMock()
|
||||
|
||||
psycopg2.pool = psycopg2_pool
|
||||
psycopg2.extras = psycopg2_extras
|
||||
return {
|
||||
"psycopg2": psycopg2,
|
||||
"psycopg2.pool": psycopg2_pool,
|
||||
"psycopg2.extras": psycopg2_extras,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def opengauss_module(monkeypatch):
    """Import the openGauss vector module against the fake psycopg2 modules.

    All three fake modules are injected before import, then the module is
    reloaded so its top-level psycopg2 imports bind to the fakes.
    """
    for name, module in _build_fake_psycopg2_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.opengauss.opengauss as module

    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module, *, enable_pq=False):
    """Return a minimal valid OpenGaussConfig; PQ indexing toggled via keyword."""
    return module.OpenGaussConfig(
        host="localhost",
        port=6600,
        user="postgres",
        password="password",
        database="dify",
        min_connection=1,
        max_connection=5,
        enable_pq=enable_pq,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OPENGAUSS_HOST is required"),
        ("port", 0, "config OPENGAUSS_PORT is required"),
        ("user", "", "config OPENGAUSS_USER is required"),
        ("password", "", "config OPENGAUSS_PASSWORD is required"),
        ("database", "", "config OPENGAUSS_DATABASE is required"),
        ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"),
        ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"),
    ],
)
def test_opengauss_config_validation(opengauss_module, field, value, message):
    """Each required config field raises a field-specific validation error when empty/zero."""
    values = _config(opengauss_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        opengauss_module.OpenGaussConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module):
    """min_connection greater than max_connection is rejected by the model validator."""
    values = _config(opengauss_module).model_dump()
    values["min_connection"] = 6
    values["max_connection"] = 5

    with pytest.raises(ValidationError, match="OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"):
        opengauss_module.OpenGaussConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
    """Init derives the table name from the collection, reports type, and keeps the pool."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))

    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))

    assert vector.table_name == "embedding_collection_1"
    assert vector.get_type() == "opengauss"
    assert vector.pool is pool
|
||||
|
||||
|
||||
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
    """With PQ enabled, _create_index emits PQ-specific SQL and records the cache flag."""
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock()))

    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True))
    cur = MagicMock()

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor
    store._create_index(1536)

    statements = [call.args[0] for call in cur.execute.call_args_list]
    assert any("enable_pq=on" in stmt for stmt in statements)
    assert any("SET hnsw_earlystop_threshold = 320" in stmt for stmt in statements)
    opengauss_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
    """A 3072-dim embedding skips index SQL entirely but still records the cache flag."""
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock()))

    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cur = MagicMock()

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor
    store._create_index(3072)

    cur.execute.assert_not_called()
    opengauss_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_search_by_vector_validates_top_k(opengauss_module):
    """search_by_vector rejects a non-positive top_k before touching the database."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        store.search_by_vector([0.1, 0.2], top_k=0)
|
||||
|
||||
|
||||
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
    """delete_by_ids with an empty id list never opens a cursor."""
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock()))
    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    store._get_cursor = MagicMock()

    store.delete_by_ids([])

    store._get_cursor.assert_not_called()
|
||||
|
||||
|
||||
def test_get_cursor_closes_commits_and_returns_connection(opengauss_module):
    """_get_cursor yields the pooled cursor, then closes it, commits and releases the connection."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    conn_pool, connection, db_cursor = MagicMock(), MagicMock(), MagicMock()
    conn_pool.getconn.return_value = connection
    connection.cursor.return_value = db_cursor
    store.pool = conn_pool

    with store._get_cursor() as yielded:
        assert yielded is db_cursor

    db_cursor.close.assert_called_once()
    connection.commit.assert_called_once()
    conn_pool.putconn.assert_called_once_with(connection)
|
||||
|
||||
|
||||
def test_create_calls_collection_insert_and_index(opengauss_module):
    """create() delegates to collection creation, text insertion and index creation in turn."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    for method in ("_create_collection", "add_texts", "_create_index"):
        setattr(store, method, MagicMock())
    documents = [Document(page_content="text", metadata={"doc_id": "seg-1"})]

    store.create(documents, [[0.1, 0.2]])

    store._create_collection.assert_called_once_with(2)
    store.add_texts.assert_called_once_with(documents, [[0.1, 0.2]])
    store._create_index.assert_called_once_with(2)
|
||||
|
||||
|
||||
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
    """_create_index is a no-op when the redis cache flag already exists."""
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock()))

    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    store._get_cursor = MagicMock()

    store._create_index(1536)

    store._get_cursor.assert_not_called()
    opengauss_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
    """Without PQ, index creation issues the standard cosine index statement."""
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock()))

    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cur = MagicMock()

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor
    store._create_index(1536)

    statements = [call.args[0] for call in cur.execute.call_args_list]
    assert any("embedding_cosine_embedding_collection_1_idx" in stmt for stmt in statements)
|
||||
|
||||
|
||||
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
    """add_texts batches rows through psycopg2's execute_values and returns metadata doc ids."""
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock()))
    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cur = MagicMock()
    opengauss_module.psycopg2.extras.execute_values.reset_mock()

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor
    monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid")
    documents = [
        Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}),
        SimpleNamespace(page_content="text-2", metadata=None),
    ]

    inserted_ids = store.add_texts(documents, [[0.1], [0.2]])

    assert inserted_ids == ["seg-1"]
    opengauss_module.psycopg2.extras.execute_values.assert_called_once()
|
||||
|
||||
|
||||
def test_text_exists_and_get_by_ids(opengauss_module):
    """text_exists reflects a fetched row; get_by_ids maps rows to Documents."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    store.table_name = "embedding_collection_1"
    cur = MagicMock()
    cur.fetchone.return_value = ("seg-1",)
    cur.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor

    assert store.text_exists("seg-1") is True

    fetched = store.get_by_ids(["seg-1", "seg-2"])
    assert len(fetched) == 2
    assert fetched[0].page_content == "text-1"
|
||||
|
||||
|
||||
def test_delete_and_metadata_field_queries(opengauss_module):
    """The delete helpers issue the expected DELETE / metadata-filter / DROP statements."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    store.table_name = "embedding_collection_1"
    cur = MagicMock()

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor

    store.delete_by_ids(["seg-1", "seg-2"])
    store.delete_by_metadata_field("document_id", "doc-1")
    store.delete()

    statements = [call.args[0] for call in cur.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in stmt for stmt in statements)
    assert any("meta->>%s = %s" in stmt for stmt in statements)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in stmt for stmt in statements)
|
||||
|
||||
|
||||
def test_search_by_vector_and_full_text(opengauss_module):
    """Vector search applies the score threshold; full-text search maps rows to Documents."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    store.table_name = "embedding_collection_1"
    cur = MagicMock()
    # Rows carry distances 0.1 and 0.6; with threshold 0.5 only the first survives.
    cur.__iter__.return_value = iter(
        [
            ({"doc_id": "1"}, "text-1", 0.1),
            ({"doc_id": "2"}, "text-2", 0.6),
        ]
    )

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor

    results = store.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(results) == 1
    assert results[0].metadata["score"] == pytest.approx(0.9)

    cur.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)])
    full_results = store.search_by_full_text("hello world", top_k=2)
    assert len(full_results) == 1
    assert full_results[0].page_content == "full-text"
|
||||
|
||||
|
||||
def test_search_by_full_text_validates_top_k(opengauss_module):
    """search_by_full_text rejects a non-positive top_k."""
    store = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        store.search_by_full_text("query", top_k=0)
|
||||
|
||||
|
||||
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
    """_create_collection is a no-op on a cache hit; otherwise it creates the table and caches."""
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock()))
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())

    store = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cur = MagicMock()

    @contextmanager
    def fake_cursor():
        yield cur

    store._get_cursor = fake_cursor

    # Cache hit: no SQL executed.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    store._create_collection(1536)
    cur.execute.assert_not_called()

    # Cache miss: table created once and the cache flag written.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    store._create_collection(1536)
    cur.execute.assert_called_once()
    opengauss_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
    """The factory reuses a stored collection name or generates one from the dataset id."""
    factory = opengauss_module.OpenGaussFactory()
    with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    config_values = {
        "OPENGAUSS_HOST": "localhost",
        "OPENGAUSS_PORT": 6600,
        "OPENGAUSS_USER": "postgres",
        "OPENGAUSS_PASSWORD": "password",
        "OPENGAUSS_DATABASE": "dify",
        "OPENGAUSS_MIN_CONNECTION": 1,
        "OPENGAUSS_MAX_CONNECTION": 5,
        "OPENGAUSS_ENABLE_PQ": False,
    }
    for key, value in config_values.items():
        monkeypatch.setattr(opengauss_module.dify_config, key, value)

    with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls:
        first = factory.init_vector(with_index, attributes=[], embeddings=MagicMock())
        second = factory.init_vector(without_index, attributes=[], embeddings=MagicMock())

    assert first == "vector"
    assert second == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert without_index.index_struct is not None
|
||||
@ -0,0 +1,360 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_opensearch_modules():
    """Construct stub opensearchpy modules so the vector module imports without the real client."""
    opensearchpy = types.ModuleType("opensearchpy")
    opensearchpy_helpers = types.ModuleType("opensearchpy.helpers")

    class BulkIndexError(Exception):
        def __init__(self, errors):
            super().__init__("bulk error")
            self.errors = errors

    class Urllib3AWSV4SignerAuth:
        def __init__(self, credentials, region, service):
            self.credentials = credentials
            self.region = region
            self.service = service

    class Urllib3HttpConnection:
        pass

    class _IndicesClient:
        def __init__(self):
            self.exists = MagicMock(return_value=False)
            self.create = MagicMock()
            self.delete = MagicMock()

    class OpenSearch:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.indices = _IndicesClient()
            self.search = MagicMock(return_value={"hits": {"hits": []}})
            self.get = MagicMock()

    # Expose the stubs under the same attribute names opensearchpy uses.
    for attr, obj in (
        ("OpenSearch", OpenSearch),
        ("Urllib3AWSV4SignerAuth", Urllib3AWSV4SignerAuth),
        ("Urllib3HttpConnection", Urllib3HttpConnection),
        ("helpers", SimpleNamespace(bulk=MagicMock())),
    ):
        setattr(opensearchpy, attr, obj)
    opensearchpy_helpers.BulkIndexError = BulkIndexError

    return {
        "opensearchpy": opensearchpy,
        "opensearchpy.helpers": opensearchpy_helpers,
    }
|
||||
|
||||
|
||||
@pytest.fixture
def opensearch_module(monkeypatch):
    """Reload the opensearch vector module against the stubbed opensearchpy packages."""
    fakes = _build_fake_opensearch_modules()
    for name, fake in fakes.items():
        monkeypatch.setitem(sys.modules, name, fake)

    import core.rag.datasource.vdb.opensearch.opensearch_vector as module

    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module, **overrides):
    """Return a valid OpenSearchConfig, with optional keyword field overrides."""
    base = {
        "host": "localhost",
        "port": 9200,
        "secure": True,
        "verify_certs": True,
        "auth_method": "basic",
        "user": "admin",
        "password": "secret",
    }
    return module.OpenSearchConfig.model_validate({**base, **overrides})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OPENSEARCH_HOST is required"),
        ("port", 0, "config OPENSEARCH_PORT is required"),
    ],
)
def test_config_validation_required_fields(opensearch_module, field, value, message):
    """An empty host or zero port must fail OpenSearchConfig validation."""
    payload = _config(opensearch_module).model_dump() | {field: value}

    with pytest.raises(ValidationError, match=message):
        opensearch_module.OpenSearchConfig.model_validate(payload)
|
||||
|
||||
|
||||
def test_config_validation_for_aws_auth_and_https_fields(opensearch_module):
    """aws_managed_iam requires AWS settings; verify_certs=True requires secure=True."""
    values = {
        "host": "localhost",
        "port": 9200,
        "secure": True,
        "verify_certs": True,
        "auth_method": "aws_managed_iam",
        "user": "admin",
        "password": "secret",
    }
    with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"):
        opensearch_module.OpenSearchConfig.model_validate(values)

    values = _config(opensearch_module).model_dump()
    # Fix: mutate the model's actual field names ("secure"/"verify_certs") as used by
    # _config above, not env-style "OPENSEARCH_*" keys the model does not declare —
    # otherwise secure stays True and the verify_certs validator never fires.
    values["secure"] = False
    values["verify_certs"] = True
    with pytest.raises(ValidationError, match="verify_certs=True requires secure"):
        opensearch_module.OpenSearchConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch):
    """create_aws_managed_iam_auth builds a signer from boto3 session credentials."""

    class _Session:
        def get_credentials(self):
            return "creds"

    fake_boto3 = types.ModuleType("boto3")
    fake_boto3.Session = _Session
    monkeypatch.setitem(sys.modules, "boto3", fake_boto3)

    config = _config(
        opensearch_module,
        auth_method="aws_managed_iam",
        aws_region="us-east-1",
        aws_service="es",
    )

    signer = config.create_aws_managed_iam_auth()

    assert (signer.credentials, signer.region, signer.service) == ("creds", "us-east-1", "es")
|
||||
|
||||
|
||||
def test_to_opensearch_params_supports_basic_and_aws(opensearch_module):
    """to_opensearch_params yields basic credentials or the AWS IAM signer."""
    assert _config(opensearch_module).to_opensearch_params()["http_auth"] == ("admin", "secret")

    aws_config = _config(
        opensearch_module,
        auth_method="aws_managed_iam",
        aws_region="us-west-2",
        aws_service="es",
    )
    with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"):
        assert aws_config.to_opensearch_params()["http_auth"] == "iam-auth"
|
||||
|
||||
|
||||
def test_init_and_create_delegate_calls(opensearch_module):
    """create() first creates the collection, then inserts the texts."""
    store = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module))
    store.create_collection = MagicMock()
    store.add_texts = MagicMock()
    documents = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]

    store.create(documents, [[0.1, 0.2]])

    assert store.get_type() == "opensearch"
    store.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}])
    store.add_texts.assert_called_once_with(documents, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch):
    """Regular clients send explicit _id fields; serverless (aoss) clients omit them."""
    store = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es"))
    documents = [
        Document(page_content="a", metadata={"doc_id": "1"}),
        Document(page_content="b", metadata={"doc_id": "2"}),
    ]
    monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id"))

    opensearch_module.helpers.bulk.reset_mock()
    store.add_texts(documents, [[0.1], [0.2]])
    regular_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"]
    assert len(regular_actions) == 2
    assert all("_id" in action for action in regular_actions)

    store._client_config.aws_service = "aoss"
    opensearch_module.helpers.bulk.reset_mock()
    store.add_texts(documents, [[0.3], [0.4]])
    aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"]
    assert all("_id" not in action for action in aoss_actions)
|
||||
|
||||
|
||||
def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module):
    """Metadata lookups return matching ids (or None) and drive metadata-based deletes."""
    store = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))

    store._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}}
    assert store.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]

    store._client.search.return_value = {"hits": {"hits": []}}
    assert store.get_ids_by_metadata_field("document_id", "doc-1") is None

    store.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    store.delete_by_ids = MagicMock()
    store.delete_by_metadata_field("document_id", "doc-1")
    store.delete_by_ids.assert_called_once_with(["id-1"])
|
||||
|
||||
|
||||
def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module):
    """delete_by_ids skips a missing index, bulk-deletes resolved ids, and tolerates 404s."""
    store = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    bulk = opensearch_module.helpers.bulk
    bulk.reset_mock()

    # Missing index: nothing to delete.
    store._client.indices.exists.return_value = False
    store.delete_by_ids(["doc-1"])
    bulk.assert_not_called()

    # Existing index: only ids that resolve to ES documents are bulk-deleted.
    store._client.indices.exists.return_value = True
    store.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None])
    store.delete_by_ids(["doc-1", "doc-2"])
    bulk.assert_called_once()

    # A BulkIndexError carrying only 404 deletes must not propagate.
    bulk.reset_mock()
    store.get_ids_by_metadata_field = MagicMock(return_value=["es-404"])
    bulk.side_effect = opensearch_module.BulkIndexError([{"delete": {"status": 404, "_id": "es-404"}}])
    store.delete_by_ids(["doc-404"])
    assert bulk.call_count == 1

    bulk.side_effect = None
|
||||
|
||||
|
||||
def test_delete_and_text_exists(opensearch_module):
    """delete() drops the index; text_exists mirrors whether the client lookup succeeds."""
    store = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))

    store.delete()
    store._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True)

    store._client.get.return_value = {"_id": "id-1"}
    assert store.text_exists("id-1") is True

    store._client.get.side_effect = RuntimeError("not found")
    assert store.text_exists("id-1") is False
|
||||
|
||||
|
||||
def test_search_by_vector_validates_and_builds_documents(opensearch_module):
    """search_by_vector validates the query vector, filters by score and builds Documents."""
    store = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))

    with pytest.raises(ValueError, match="query_vector should be a list"):
        store.search_by_vector("not-a-list")
    with pytest.raises(ValueError, match="should be floats"):
        store.search_by_vector([0.1, 1])

    def _hit(content, metadata, score):
        # Shape of a single opensearch hit as the client returns it.
        return {
            "_source": {
                opensearch_module.Field.CONTENT_KEY: content,
                opensearch_module.Field.METADATA_KEY: metadata,
            },
            "_score": score,
        }

    store._client.search.return_value = {
        "hits": {"hits": [_hit("doc-1", None, 0.9), _hit("doc-2", {"doc_id": "2"}, 0.1)]}
    }
    results = store.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(results) == 1
    assert results[0].page_content == "doc-1"
    assert results[0].metadata["score"] == pytest.approx(0.9)

    store.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"])
    body = store._client.search.call_args.kwargs["body"]
    assert "script_score" in body["query"]
|
||||
|
||||
|
||||
def test_search_by_vector_reraises_client_error(opensearch_module):
    """Client failures during vector search propagate to the caller."""
    store = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    store._client.search.side_effect = RuntimeError("boom")

    with pytest.raises(RuntimeError, match="boom"):
        store.search_by_vector([0.1, 0.2])
|
||||
|
||||
|
||||
def test_search_by_full_text_and_filters(opensearch_module):
    """Full-text search builds Documents and applies a document-id terms filter."""
    store = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    store._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        opensearch_module.Field.METADATA_KEY: {"doc_id": "1"},
                        opensearch_module.Field.VECTOR: [0.1],
                        opensearch_module.Field.CONTENT_KEY: "matched text",
                    }
                },
            ]
        }
    }

    results = store.search_by_full_text("hello", document_ids_filter=["d-1"])

    assert len(results) == 1
    assert results[0].page_content == "matched text"
    body = store._client.search.call_args.kwargs["body"]
    assert body["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}]
|
||||
|
||||
|
||||
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch):
    """create_collection is a no-op on a cache hit; otherwise it creates the index and caches."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock())

    store = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module))

    # Cache hit: no index creation.
    monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1))
    store._client.indices.create.reset_mock()
    store.create_collection([[0.1, 0.2]])
    store._client.indices.create.assert_not_called()

    # Cache miss: index created with the embedding dimension, then flag cached.
    monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None))
    store._client.indices.exists.return_value = False
    store.create_collection([[0.1, 0.2]])
    store._client.indices.create.assert_called_once()
    index_body = store._client.indices.create.call_args.kwargs["body"]
    assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2
    opensearch_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch):
    """The factory lower-cases a stored collection name or generates one from the dataset id."""
    factory = opensearch_module.OpenSearchVectorFactory()
    with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    config_values = {
        "OPENSEARCH_HOST": "localhost",
        "OPENSEARCH_PORT": 9200,
        "OPENSEARCH_SECURE": True,
        "OPENSEARCH_VERIFY_CERTS": True,
        "OPENSEARCH_AUTH_METHOD": "basic",
        "OPENSEARCH_USER": "admin",
        "OPENSEARCH_PASSWORD": "secret",
        "OPENSEARCH_AWS_REGION": None,
        "OPENSEARCH_AWS_SERVICE": None,
    }
    for key, value in config_values.items():
        monkeypatch.setattr(opensearch_module.dify_config, key, value)

    with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls:
        first = factory.init_vector(with_index, attributes=[], embeddings=MagicMock())
        second = factory.init_vector(without_index, attributes=[], embeddings=MagicMock())

    assert first == "vector"
    assert second == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    assert without_index.index_struct is not None
|
||||
@ -0,0 +1,375 @@
|
||||
import array
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_oracle_modules():
    """Construct stub jieba/oracledb modules so the oracle vector module imports cleanly."""
    modules = {
        name: types.ModuleType(name)
        for name in ("jieba", "jieba.posseg", "oracledb", "oracledb.connection")
    }

    modules["jieba.posseg"].cut = MagicMock(return_value=[])
    modules["jieba"].posseg = modules["jieba.posseg"]

    class Connection:
        pass

    modules["oracledb.connection"].Connection = Connection
    oracledb = modules["oracledb"]
    oracledb.defaults = SimpleNamespace(fetch_lobs=True)
    oracledb.DB_TYPE_VECTOR = object()
    oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock()))
    oracledb.connect = MagicMock()

    return modules
|
||||
|
||||
|
||||
def _connection_with_cursor(cursor):
    """Return a mock connection whose cursor() context manager yields *cursor*."""
    cursor_cm = MagicMock()
    cursor_cm.__enter__.return_value = cursor
    cursor_cm.__exit__.return_value = None

    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.cursor.return_value = cursor_cm
    return conn
|
||||
|
||||
|
||||
@pytest.fixture
def oracle_module(monkeypatch):
    """Reload the oracle vector module against the stubbed jieba/oracledb packages."""
    fakes = _build_fake_oracle_modules()
    for name, fake in fakes.items():
        monkeypatch.setitem(sys.modules, name, fake)

    import core.rag.datasource.vdb.oracle.oraclevector as module

    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module, **overrides):
    """Return a valid OracleVectorConfig, with optional keyword field overrides."""
    base = {
        "user": "system",
        "password": "oracle",
        "dsn": "oracle:1521/freepdb1",
        "is_autonomous": False,
    }
    return module.OracleVectorConfig.model_validate({**base, **overrides})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("user", "", "config ORACLE_USER is required"),
        ("password", "", "config ORACLE_PASSWORD is required"),
        ("dsn", "", "config ORACLE_DSN is required"),
    ],
)
def test_oracle_config_validation_required_fields(oracle_module, field, value, message):
    """Every required Oracle connection field rejects an empty value."""
    payload = _config(oracle_module).model_dump() | {field: value}

    with pytest.raises(ValidationError, match=message):
        oracle_module.OracleVectorConfig.model_validate(payload)
|
||||
|
||||
|
||||
def test_oracle_config_validation_autonomous_requirements(oracle_module):
    """Autonomous mode requires a wallet config_dir."""
    payload = {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True}

    with pytest.raises(ValidationError, match="config_dir is required"):
        oracle_module.OracleVectorConfig.model_validate(payload)
|
||||
|
||||
|
||||
def test_init_and_get_type(oracle_module, monkeypatch):
    """Constructing OracleVector derives the table name and keeps the created pool."""
    fake_pool = MagicMock()
    monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=fake_pool))

    store = oracle_module.OracleVector("collection_1", _config(oracle_module))

    assert store.pool is fake_pool
    assert store.table_name == "embedding_collection_1"
    assert store.get_type() == "oracle"
|
||||
|
||||
|
||||
def test_numpy_converters_and_type_handlers(oracle_module):
    """Numpy<->DB converters pick matching array typecodes; oracledb type handlers wire them into cursor.var."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)

    # Input direction: numpy dtype -> array.array typecode.
    in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64))
    in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32))
    in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8))
    assert in_float64.typecode == "d"
    assert in_float32.typecode == "f"
    assert in_int8.typecode == "b"

    # Input handler registers the in-converter for DB_TYPE_VECTOR binds.
    cursor = MagicMock()
    vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2)
    cursor.var.assert_called_with(
        oracle_module.oracledb.DB_TYPE_VECTOR,
        arraysize=2,
        inconverter=vector.numpy_converter_in,
    )

    # Output handler registers the out-converter for vector-typed result columns.
    metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR)
    cursor.arraysize = 3
    vector.output_type_handler(cursor, metadata)
    cursor.var.assert_called_with(
        metadata.type_code,
        arraysize=3,
        outconverter=vector.numpy_converter_out,
    )

    # Output direction: array.array typecode -> numpy dtype.
    out_int8 = vector.numpy_converter_out(array.array("b", [1]))
    assert out_int8.dtype == numpy.int8
    out_float32 = vector.numpy_converter_out(array.array("f", [1.0]))
    assert out_float32.dtype == numpy.float32
    out_float64 = vector.numpy_converter_out(array.array("d", [1.0]))
    assert out_float64.dtype == numpy.float64
|
||||
|
||||
|
||||
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch):
    """_get_connection passes wallet kwargs only when the config is autonomous."""
    connect = MagicMock(return_value="connection")
    monkeypatch.setattr(oracle_module.oracledb, "connect", connect)

    # Standard path: plain user/password/dsn connection.
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.config = _config(oracle_module)
    assert vector._get_connection() == "connection"
    connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1")

    # Autonomous path: wallet configuration is forwarded to connect().
    vector.config = _config(
        oracle_module,
        is_autonomous=True,
        config_dir="/wallet",
        wallet_location="/wallet",
        wallet_password="pw",
    )
    vector._get_connection()
    assert connect.call_args.kwargs["config_dir"] == "/wallet"
    assert connect.call_args.kwargs["wallet_location"] == "/wallet"
|
||||
|
||||
|
||||
def test_create_delegates_collection_and_insert(oracle_module):
    """create() sizes the collection to the embedding dim, then delegates inserting to add_texts."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["seg-1"])
    docs = [Document(page_content="doc", metadata={"doc_id": "seg-1"})]

    result = vector.create(docs, [[0.1, 0.2]])

    assert result == ["seg-1"]
    # 2 == len of the first embedding vector.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch):
    """add_texts returns doc ids (generated when doc_id is absent) and keeps going when an insert fails."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.input_type_handler = MagicMock()
    vector.output_type_handler = MagicMock()

    cursor = MagicMock()
    # Exactly two execute calls expected: the first succeeds, the second raises.
    cursor.execute.side_effect = [None, RuntimeError("insert failed")]
    connection = _connection_with_cursor(cursor)
    vector._get_connection = MagicMock(return_value=connection)

    monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        # No metadata at all — exercises the None-metadata branch.
        SimpleNamespace(page_content="c", metadata=None),
    ]

    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])

    # Second doc has no "doc_id" key, so its id comes from the patched uuid4.
    assert ids == ["doc-a", "generated-uuid"]
    assert cursor.execute.call_count == 2
    assert connection.commit.call_count >= 1
    connection.close.assert_called()
|
||||
|
||||
|
||||
def test_text_exists_and_get_by_ids(oracle_module):
    """text_exists reports a found row; get_by_ids maps rows to Documents and releases the pool."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.pool = MagicMock()

    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    # Rows are (meta, text) pairs.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))

    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
    vector.pool.release.assert_called_once()
    # An empty id list short-circuits to an empty result.
    assert vector.get_by_ids([]) == []
|
||||
|
||||
|
||||
def test_delete_methods(oracle_module):
    """delete_by_ids / delete_by_metadata_field / delete issue the expected SQL; empty input is a no-op."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"

    cursor = MagicMock()
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))

    # Empty id list must not even open a connection.
    vector.delete_by_ids([])
    vector._get_connection.assert_not_called()

    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()

    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql)
    assert any("JSON_VALUE(meta" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
|
||||
|
||||
|
||||
def test_search_by_vector_with_threshold_and_filter(oracle_module):
    """search_by_vector converts distance to score, applies the threshold, and filters by document ids."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.input_type_handler = MagicMock()
    vector.output_type_handler = MagicMock()

    cursor = MagicMock()
    # Rows are (meta, text, distance); distance 0.1 -> score 0.9 passes the 0.5 threshold, 0.8 -> 0.2 does not.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, "doc-2", 0.8)])
    connection = _connection_with_cursor(cursor)
    vector._get_connection = MagicMock(return_value=connection)

    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=0,  # non-positive top_k falls back to fetching 4 rows (asserted below)
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )

    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    sql = cursor.execute.call_args.args[0]
    assert "fetch first 4 rows only" in sql
    assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql
|
||||
|
||||
|
||||
def _fake_nltk_module(*, missing_data=False):
|
||||
nltk = types.ModuleType("nltk")
|
||||
nltk_corpus = types.ModuleType("nltk.corpus")
|
||||
|
||||
class _Data:
|
||||
@staticmethod
|
||||
def find(_path):
|
||||
if missing_data:
|
||||
raise LookupError("missing")
|
||||
return True
|
||||
|
||||
nltk.data = _Data()
|
||||
nltk.word_tokenize = lambda text: text.split()
|
||||
nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"])
|
||||
return nltk, nltk_corpus
|
||||
|
||||
|
||||
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch):
    """Full-text search segments Chinese queries via pseg and English queries via (stubbed) NLTK."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"

    cursor = MagicMock()
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])])
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))

    # Chinese path: pseg tokens ("x" flagged punctuation excluded) form the "kk" keyword bind.
    monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("张", "nr"), ("三", "nr"), ("。", "x")]))
    zh_docs = vector.search_by_full_text("张三", top_k=2)
    assert len(zh_docs) == 1
    zh_params = cursor.execute.call_args.args[1]
    assert zh_params["kk"] == "张三"

    # English path: stub NLTK tokenization/stopwords; non-positive top_k falls back to 5 rows.
    nltk, nltk_corpus = _fake_nltk_module(missing_data=False)
    monkeypatch.setitem(sys.modules, "nltk", nltk)
    monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus)
    cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])])
    en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"])
    assert len(en_docs) == 1
    en_sql = cursor.execute.call_args.args[0]
    en_params = cursor.execute.call_args.args[1]
    assert "fetch first 5 rows only" in en_sql
    assert "doc_id_0" in en_params
|
||||
|
||||
|
||||
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch):
    """An empty query yields an empty-content document; missing NLTK data raises LookupError."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector._get_connection = MagicMock()

    empty_result = vector.search_by_full_text("")
    assert empty_result[0].page_content == ""

    # Stub NLTK whose data.find raises, so the English path must surface a clear LookupError.
    nltk, nltk_corpus = _fake_nltk_module(missing_data=True)
    monkeypatch.setitem(sys.modules, "nltk", nltk)
    monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus)
    with pytest.raises(LookupError, match="required NLTK data package"):
        vector.search_by_full_text("english query")
|
||||
|
||||
|
||||
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
    """_create_collection skips DDL when the redis flag is set; otherwise creates the table and index."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock())

    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector._collection_name = "collection_1"
    vector.table_name = "embedding_collection_1"

    cursor = MagicMock()
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))

    # Cache hit: no SQL executed at all.
    monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(2)
    cursor.execute.assert_not_called()

    # Cache miss: DDL runs and the flag is written back to redis.
    monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(2)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql)
    assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql)
    oracle_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch):
    """The factory reuses the index-struct collection name, or generates one from the dataset id."""
    factory = oracle_module.OracleVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide a complete (non-autonomous) Oracle config on the stubbed dify_config.
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False)

    with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # init_vector must persist a generated index struct on the dataset.
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,317 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.types import UserDefinedType
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_pgvecto_modules():
    """Create stub ``pgvecto_rs`` modules exposing a minimal VECTOR SQLAlchemy type."""
    root_module = types.ModuleType("pgvecto_rs")
    sqlalchemy_module = types.ModuleType("pgvecto_rs.sqlalchemy")

    class VECTOR(UserDefinedType):
        # Only the dimension attribute is needed by the code under test.
        def __init__(self, dim):
            self.dim = dim

    sqlalchemy_module.VECTOR = VECTOR
    return {
        "pgvecto_rs": root_module,
        "pgvecto_rs.sqlalchemy": sqlalchemy_module,
    }
|
||||
|
||||
|
||||
class _FakeSessionContext:
|
||||
def __init__(self, calls, execute_results=None):
|
||||
self.calls = calls
|
||||
self.execute_results = execute_results or []
|
||||
self.execute = MagicMock(side_effect=self._execute_side_effect)
|
||||
self.commit = MagicMock()
|
||||
|
||||
def _execute_side_effect(self, *args, **kwargs):
|
||||
self.calls.append((args, kwargs))
|
||||
if self.execute_results:
|
||||
return self.execute_results.pop(0)
|
||||
return MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _session_factory(calls, execute_results=None):
    """Return a Session-compatible callable that yields _FakeSessionContext objects bound to *calls*."""

    def _make_session(_client):
        return _FakeSessionContext(calls=calls, execute_results=execute_results)

    return _make_session
|
||||
|
||||
|
||||
@pytest.fixture
def pgvecto_module(monkeypatch):
    """Install fake pgvecto_rs modules, then reload the vdb module and its collection module."""
    for name, module in _build_fake_pgvecto_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module
    import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module

    # Reload so module-level imports re-bind to the fakes installed above.
    return importlib.reload(module), importlib.reload(collection_module)
|
||||
|
||||
|
||||
def _config(module, **overrides):
|
||||
values = {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "postgres",
|
||||
"password": "secret",
|
||||
"database": "postgres",
|
||||
}
|
||||
values.update(overrides)
|
||||
return module.PgvectoRSConfig.model_validate(values)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config PGVECTO_RS_HOST is required"),
        ("port", 0, "config PGVECTO_RS_PORT is required"),
        ("user", "", "config PGVECTO_RS_USER is required"),
        ("password", "", "config PGVECTO_RS_PASSWORD is required"),
        ("database", "", "config PGVECTO_RS_DATABASE is required"),
    ],
)
def test_pgvecto_config_validation(pgvecto_module, field, value, message):
    """Each required PgvectoRS config field rejects empty/zero values with a descriptive message."""
    module, _ = pgvecto_module
    values = _config(module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        module.PgvectoRSConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_collection_base_has_expected_annotations(pgvecto_module):
    """CollectionORM declares at least the id/text/meta/vector columns."""
    _, collection_module = pgvecto_module
    annotations = collection_module.CollectionORM.__annotations__
    assert {"id", "text", "meta", "vector"} <= set(annotations)
|
||||
|
||||
|
||||
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
    """Construction builds the engine and vectors extension; create() delegates to collection setup and add_texts."""
    module, _ = pgvecto_module
    session_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(session_calls))

    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1, 0.2]])

    assert vector.get_type() == module.VectorType.PGVECTO_RS
    module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres")
    # The constructor ensures the "vectors" extension exists.
    assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls)
    vector.create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
    """create_collection skips DDL on a redis cache hit; on a miss it runs table and index DDL."""
    module, _ = pgvecto_module
    session_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(session_calls))

    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(module.redis_client, "set", MagicMock())

    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    # Cache hit: no CREATE TABLE issued.
    monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection(3)
    assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls)

    # Cache miss: DDL executes and the flag is written back to redis.
    monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None))
    vector.create_collection(3)
    assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls)
    assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls)
    module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
    """add_texts inserts rows via insert().values(); get/delete helpers issue the expected SQL."""
    module, _ = pgvecto_module
    init_calls = []
    runtime_calls = []
    execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]

    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(init_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)

    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))

    # Minimal stand-in for sqlalchemy.insert so the executed statement stays inspectable.
    class _InsertBuilder:
        def __init__(self, table):
            self.table = table

        def values(self, **kwargs):
            return ("insert", kwargs)

    monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table))
    # Deterministic row ids.
    monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"]))
    docs = [
        Document(page_content="a", metadata={"doc_id": "1"}),
        Document(page_content="b", metadata={"doc_id": "2"}),
    ]

    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["uuid-1", "uuid-2"]
    assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0])

    # get_ids_by_metadata_field: matching rows -> id list; no rows -> None.
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
    )
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]

    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]),
    )
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None

    # delete_by_metadata_field deletes the ids found for the field.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)

    # delete_by_ids first resolves row ids by doc_id, then deletes them.
    runtime_calls.clear()
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(
            runtime_calls,
            execute_results=[
                SimpleNamespace(fetchall=lambda: [("row-id-1",)]),
                MagicMock(),
            ],
        ),
    )
    vector.delete_by_ids(["doc-1"])
    assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
    assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)

    # delete() drops the whole table.
    runtime_calls.clear()
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
    vector.delete()
    assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
|
||||
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
    """text_exists reads fetchall; search_by_vector scores/filters rows; full-text search returns []."""
    module, _ = pgvecto_module
    init_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(init_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)

    runtime_calls = []
    # First call finds a row, second finds none.
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(
            runtime_calls,
            execute_results=[
                SimpleNamespace(fetchall=lambda: [("id-1",)]),
                SimpleNamespace(fetchall=lambda: []),
            ],
        ),
    )
    assert vector.text_exists("doc-1") is True
    assert vector.text_exists("doc-1") is False

    # Stubs for the SQLAlchemy expression objects used to build the search statement.
    class _DistanceExpr:
        def label(self, _name):
            return self

    class _VectorColumn:
        def op(self, _operator, return_type=None):
            def _call(_query_vector):
                return _DistanceExpr()

            return _call

    class _MetaFilter:
        def in_(self, values):
            return ("in", values)

    class _MetaColumn:
        def __getitem__(self, _item):
            return _MetaFilter()

    class _Stmt:
        def __init__(self):
            self.where_called = False

        def limit(self, _value):
            return self

        def order_by(self, _value):
            return self

        def where(self, _value):
            self.where_called = True
            return self

    stmt = _Stmt()
    monkeypatch.setattr(module, "select", lambda *_args: stmt)

    vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn())
    # Rows are (record, distance); 0.1 -> score 0.9 passes the 0.5 threshold, 0.8 -> 0.2 is dropped.
    rows = [
        (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
        (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
    ]
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))

    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # The document-id filter must have added a where() clause.
    assert stmt.where_called is True
    assert vector.search_by_full_text("hello") == []
|
||||
|
||||
|
||||
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch):
    """The factory reuses (lowercased) the index-struct collection name, or generates one from the dataset id."""
    module, _ = pgvecto_module
    factory = module.PGVectoRSFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432)
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres")

    # The factory probes the embedding dimension via embed_query.
    embeddings = MagicMock()
    embeddings.embed_query.return_value = [0.1, 0.2, 0.3]

    with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings)
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings)

    assert result_1 == "vector"
    assert result_2 == "vector"
    # PGVectoRS collection names are lowercased.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    assert dataset_without_index.index_struct is not None
|
||||
@ -1,16 +1,19 @@
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import (
|
||||
PGVector,
|
||||
PGVectorConfig,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class TestPGVector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
class TestPGVector:
|
||||
def setup_method(self, method):
|
||||
self.config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override):
|
||||
PGVectorConfig(**config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
def test_create_delegates_collection_creation_and_insert():
    """create() creates the collection sized to the embedding dim, then delegates to add_texts."""
    vector = PGVector.__new__(PGVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["doc-a"])
    docs = [Document(page_content="hello", metadata={"doc_id": "doc-a"})]

    result = vector.create(docs, [[0.1, 0.2]])

    assert result == ["doc-a"]
    # 2 == len of the first embedding vector.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch):
    """add_texts bulk-inserts via psycopg2 execute_values and returns ids (generated when doc_id is absent)."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"

    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid")
    execute_values = MagicMock()
    monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values)

    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        # No "doc_id" key -> id comes from the patched uuid4.
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        # No metadata at all — exercises the None-metadata branch.
        SimpleNamespace(page_content="c", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])

    assert ids == ["doc-a", "generated-uuid"]
    execute_values.assert_called_once()
|
||||
|
||||
|
||||
def test_text_get_and_delete_methods():
    """text_exists/get_by_ids read rows; metadata delete and table drop emit the expected SQL."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    # Rows are (meta, text) pairs.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"

    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("meta->>%s = %s" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
|
||||
|
||||
|
||||
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch):
    """delete_by_ids skips empty input, tolerates a missing table, and re-raises other errors."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # Empty id list: no SQL executed.
    vector.delete_by_ids([])
    cursor.execute.assert_not_called()

    class _UndefinedTableError(Exception):
        pass

    # UndefinedTable is swallowed (nothing to delete); any other exception propagates.
    monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError)
    cursor.execute.side_effect = _UndefinedTableError("missing")
    vector.delete_by_ids(["doc-1"])

    cursor.execute.side_effect = RuntimeError("boom")
    with pytest.raises(RuntimeError, match="boom"):
        vector.delete_by_ids(["doc-1"])
|
||||
|
||||
|
||||
def test_search_by_vector_supports_filter_and_threshold():
    """search_by_vector validates top_k, applies the score threshold, and inlines the document-id filter."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows are (meta, text, distance); distance 0.1 -> score 0.9 passes the 0.5 threshold.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1], top_k=0)

    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    sql = cursor.execute.call_args.args[0]
    assert "meta->>'document_id' in ('d-1')" in sql
|
||||
|
||||
|
||||
def test_search_by_full_text_branches_for_bigm_and_standard():
    """Full-text search uses tsvector by default and pg_bigm similarity when pg_bigm is enabled."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("hello", top_k=0)

    # Standard path: plainto_tsquery against the tsvector of the text column.
    vector.pg_bigm = False
    docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.7)
    standard_sql = cursor.execute.call_args.args[0]
    assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql

    # pg_bigm path: similarity limit is set first, then the bigm_similarity query runs.
    cursor.execute.reset_mock()
    cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)])
    vector.pg_bigm = True
    vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"])
    assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0]
    assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0]
|
||||
|
||||
|
||||
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch):
    """Factory keeps an existing class_prefix and generates a name for new datasets."""
    factory = pgvector_module.PGVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    # Pin all config the factory reads so the test is hermetic.
    monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False)

    with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # The generated index struct must be persisted back onto the dataset object.
    assert dataset_without_index.index_struct is not None
|
||||
|
||||
@ -0,0 +1,269 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_psycopg2_modules():
|
||||
psycopg2 = types.ModuleType("psycopg2")
|
||||
psycopg2.__path__ = []
|
||||
psycopg2_extras = types.ModuleType("psycopg2.extras")
|
||||
psycopg2_pool = types.ModuleType("psycopg2.pool")
|
||||
|
||||
class SimpleConnectionPool:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.getconn = MagicMock()
|
||||
self.putconn = MagicMock()
|
||||
|
||||
psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool
|
||||
psycopg2_extras.execute_values = MagicMock()
|
||||
psycopg2.pool = psycopg2_pool
|
||||
psycopg2.extras = psycopg2_extras
|
||||
|
||||
return {
|
||||
"psycopg2": psycopg2,
|
||||
"psycopg2.pool": psycopg2_pool,
|
||||
"psycopg2.extras": psycopg2_extras,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def vastbase_module(monkeypatch):
    """Import the vastbase vector module against stubbed psycopg2 modules."""
    for name, module in _build_fake_psycopg2_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module

    # Reload so the module re-binds to the stubbed psycopg2 even if it was
    # imported earlier in the test session with the real driver.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module):
|
||||
return module.VastbaseVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="dify",
|
||||
password="secret",
|
||||
database="dify",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config VASTBASE_HOST is required"),
        ("port", 0, "config VASTBASE_PORT is required"),
        ("user", "", "config VASTBASE_USER is required"),
        ("password", "", "config VASTBASE_PASSWORD is required"),
        ("database", "", "config VASTBASE_DATABASE is required"),
        ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"),
        ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"),
    ],
)
def test_vastbase_config_validation(vastbase_module, field, value, message):
    """Each required config field must raise a ValidationError when empty/zero."""
    # Start from a known-valid config and break exactly one field per case.
    values = _config(vastbase_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        vastbase_module.VastbaseVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_vastbase_config_rejects_invalid_connection_window(vastbase_module):
    """A min_connection larger than max_connection must be rejected."""
    values = {
        "host": "localhost",
        "port": 5432,
        "user": "dify",
        "password": "secret",
        "database": "dify",
        "min_connection": 6,
        "max_connection": 5,
    }
    with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"):
        vastbase_module.VastbaseVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
    """_get_cursor must close the cursor, commit, and return the connection to the pool."""
    pool = MagicMock()
    monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))

    conn = MagicMock()
    cur = MagicMock()
    pool.getconn.return_value = conn
    conn.cursor.return_value = cur

    vector = vastbase_module.VastbaseVector("collection_1", _config(vastbase_module))
    assert vector.get_type() == "vastbase"
    assert vector.table_name == "embedding_collection_1"

    with vector._get_cursor() as got_cur:
        assert got_cur is cur

    # Cleanup must run on context-manager exit.
    cur.close.assert_called_once()
    conn.commit.assert_called_once()
    pool.putconn.assert_called_once_with(conn)
|
||||
|
||||
|
||||
def test_create_and_add_texts(vastbase_module, monkeypatch):
    """add_texts should use doc_id, fall back to a generated uuid, and skip docs without metadata."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    vector._create_collection = MagicMock()

    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # Deterministic fallback id for the doc whose metadata lacks "doc_id".
    monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid")

    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        SimpleNamespace(page_content="c", metadata=None),
    ]

    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # Only the two docs with metadata get ids; the metadata=None doc is skipped.
    assert ids == ["doc-a", "generated-uuid"]
    vastbase_module.psycopg2.extras.execute_values.assert_called_once()

    # create() delegates: collection creation sized by embedding dim, then add_texts.
    vector.add_texts = MagicMock(return_value=["doc-a"])
    result = vector.create(docs, [[0.1], [0.2], [0.3]])
    vector._create_collection.assert_called_once_with(1)
    assert result == ["doc-a"]
|
||||
|
||||
|
||||
def test_text_get_delete_and_metadata_methods(vastbase_module):
    """Exercise text_exists, get_by_ids and the three delete flavours via one mocked cursor."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"

    # Empty-id delete must be a no-op; the rest issue DELETE/DROP statements.
    vector.delete_by_ids([])
    vector.delete_by_ids(["id-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql)
    assert any("meta->>%s = %s" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
|
||||
|
||||
|
||||
def test_search_by_vector_and_full_text(vastbase_module):
    """Vector search should validate top_k, apply the score threshold, and convert distance to score."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows are (meta, text, distance); distance 0.1 -> score 0.9, 0.8 -> 0.2.
    cursor.__iter__.return_value = iter(
        [
            ({"doc_id": "1"}, "text-1", 0.1),
            ({"doc_id": "2"}, "text-2", 0.8),
        ]
    )

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1, 0.2], top_k=0)

    # With threshold 0.5 only the first row (score 0.9) survives.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)

    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("hello", top_k=0)

    cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)])
    full_docs = vector.search_by_full_text("hello world", top_k=2)
    assert len(full_docs) == 1
    assert full_docs[0].page_content == "full-text"
|
||||
|
||||
|
||||
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch):
    """_create_collection should honor the redis cache flag and skip the index for huge dimensions."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock())

    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector._collection_name = "collection_1"
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx

    # Cache hit: no SQL at all.
    monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    cursor.execute.assert_not_called()

    # Cache miss with an oversized dimension: table created but no cosine index.
    monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(17000)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql)
    assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql)

    # Cache miss with a small dimension: cosine index is created and the flag is cached.
    cursor.execute.reset_mock()
    vector._create_collection(3)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql)
    vastbase_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch):
    """Factory keeps an existing class_prefix and generates a name for new datasets."""
    factory = vastbase_module.VastbaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    # Pin all config the factory reads so the test is hermetic.
    monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432)
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1)
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MAX_CONNECTION", 5)

    with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # The generated index struct must be persisted back onto the dataset object.
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,328 @@
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from collections import UserDict
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_qdrant_modules():
    """Create stub ``qdrant_client`` modules so the vector module imports without the real client.

    Returns a mapping of dotted module name -> module object, suitable for
    insertion into ``sys.modules`` before importing the code under test.
    """
    qdrant_client = types.ModuleType("qdrant_client")
    qdrant_http = types.ModuleType("qdrant_client.http")
    qdrant_http_models = types.ModuleType("qdrant_client.http.models")
    qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions")
    qdrant_local_pkg = types.ModuleType("qdrant_client.local")
    qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local")

    class UnexpectedResponseError(Exception):
        # Mirrors qdrant's UnexpectedResponse enough for status_code checks.
        def __init__(self, status_code):
            super().__init__(f"status={status_code}")
            self.status_code = status_code

    # The model stand-ins below only record their constructor arguments.
    class FilterSelector:
        def __init__(self, filter):
            self.filter = filter

    class HnswConfigDiff:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class TextIndexParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class VectorParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class PointStruct:
        def __init__(self, **kwargs):
            self.id = kwargs["id"]
            self.vector = kwargs["vector"]
            self.payload = kwargs["payload"]

    class Filter:
        def __init__(self, must=None):
            self.must = must or []

    class FieldCondition:
        def __init__(self, key, match):
            self.key = key
            self.match = match

    class MatchValue:
        def __init__(self, value):
            self.value = value

    class MatchAny:
        def __init__(self, any):
            self.any = any

    class MatchText:
        def __init__(self, text):
            self.text = text

    class _Distance(UserDict):
        # Identity mapping: Distance["COSINE"] -> "COSINE".
        def __getitem__(self, key):
            return key

    class QdrantClient:
        # Every client method the code under test touches is a MagicMock.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[]))
            self.create_collection = MagicMock()
            self.create_payload_index = MagicMock()
            self.upsert = MagicMock()
            self.delete = MagicMock()
            self.delete_collection = MagicMock()
            self.retrieve = MagicMock(return_value=[])
            self.search = MagicMock(return_value=[])
            self.scroll = MagicMock(return_value=([], None))

    class QdrantLocal(QdrantClient):
        # Local variant additionally exposes a mocked _load for reload tests.
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._load = MagicMock()

    qdrant_client.QdrantClient = QdrantClient
    qdrant_http_models.FilterSelector = FilterSelector
    qdrant_http_models.HnswConfigDiff = HnswConfigDiff
    qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD")
    qdrant_http_models.TextIndexParams = TextIndexParams
    qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT")
    qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL")
    qdrant_http_models.VectorParams = VectorParams
    qdrant_http_models.Distance = _Distance()
    qdrant_http_models.PointStruct = PointStruct
    qdrant_http_models.Filter = Filter
    qdrant_http_models.FieldCondition = FieldCondition
    qdrant_http_models.MatchValue = MatchValue
    qdrant_http_models.MatchAny = MatchAny
    qdrant_http_models.MatchText = MatchText
    qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError

    qdrant_http.models = qdrant_http_models
    qdrant_local_mod.QdrantLocal = QdrantLocal
    qdrant_local_pkg.qdrant_local = qdrant_local_mod

    return {
        "qdrant_client": qdrant_client,
        "qdrant_client.http": qdrant_http,
        "qdrant_client.http.models": qdrant_http_models,
        "qdrant_client.http.exceptions": qdrant_http_exceptions,
        "qdrant_client.local": qdrant_local_pkg,
        "qdrant_client.local.qdrant_local": qdrant_local_mod,
    }
|
||||
|
||||
|
||||
@pytest.fixture
def qdrant_module(monkeypatch):
    """Import the qdrant vector module against stubbed qdrant_client modules."""
    for name, module in _build_fake_qdrant_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.qdrant.qdrant_vector as module

    # Reload so the module re-binds to the stubbed client even if it was
    # imported earlier in the test session with the real dependency.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module, **overrides):
|
||||
values = {
|
||||
"endpoint": "http://localhost:6333",
|
||||
"api_key": "api-key",
|
||||
"timeout": 20,
|
||||
"root_path": "/tmp",
|
||||
"grpc_port": 6334,
|
||||
"prefer_grpc": False,
|
||||
"replication_factor": 1,
|
||||
"write_consistency_factor": 1,
|
||||
}
|
||||
values.update(overrides)
|
||||
return module.QdrantConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_qdrant_config_to_params(qdrant_module):
    """to_qdrant_params should branch on URL vs local ``path:`` endpoints."""
    url_params = _config(qdrant_module).to_qdrant_params().model_dump()
    assert url_params["url"] == "http://localhost:6333"
    assert url_params["verify"] is False

    # "path:" endpoints resolve relative to root_path.
    path_config = _config(qdrant_module, endpoint="path:storage")
    assert path_config.to_qdrant_params().path == os.path.join("/tmp", "storage")

    # A path endpoint without a root_path must be rejected.
    with pytest.raises(ValueError, match="Root path is not set"):
        _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params()
|
||||
|
||||
|
||||
def test_init_and_basic_behaviour(qdrant_module):
    """create() must create the collection sized by the embedding dim, then add texts."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    assert vector.get_type() == qdrant_module.VectorType.QDRANT
    assert vector.to_index_struct()["vector_store"]["class_prefix"] == "collection_1"

    docs = [Document(page_content="a", metadata={"doc_id": "a"})]
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once_with("collection_1", 1)
    vector.add_texts.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_and_add_texts(qdrant_module, monkeypatch):
    """Collection creation honors the redis cache flag; add_texts upserts and returns doc ids."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock())

    # Cache hit: collection creation skipped entirely.
    monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection("collection_1", 3)
    vector._client.create_collection.assert_not_called()

    # Cache miss: collection plus four payload indexes, then the flag is cached.
    monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.get_collections.return_value = SimpleNamespace(collections=[])
    vector.create_collection("collection_1", 3)
    vector._client.create_collection.assert_called_once()
    assert vector._client.create_payload_index.call_count == 4
    qdrant_module.redis_client.set.assert_called_once()

    # add_texts and generated batches
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    assert vector._client.upsert.call_count == 1

    # _build_payloads attaches the group id and rejects None texts.
    payloads = qdrant_module.QdrantVector._build_payloads(
        ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id"
    )
    assert payloads[0]["group_id"] == "g1"
    with pytest.raises(ValueError, match="At least one of the texts is None"):
        qdrant_module.QdrantVector._build_payloads(
            [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id"
        )
|
||||
|
||||
|
||||
def test_delete_and_exists_paths(qdrant_module):
    """404 responses from delete are swallowed; any other status re-raises."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse

    # delete_by_metadata_field: 404 tolerated, 500 propagated.
    vector._client.delete.side_effect = unexpected(404)
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.side_effect = None

    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.side_effect = None

    # delete(): same 404-vs-500 contract.
    vector._client.delete.side_effect = unexpected(404)
    vector.delete()
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete()
    vector._client.delete.side_effect = None

    # delete_by_ids(): same 404-vs-500 contract.
    vector._client.delete.side_effect = unexpected(404)
    vector.delete_by_ids(["doc-1"])
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete_by_ids(["doc-1"])
    vector._client.delete.side_effect = None

    # text_exists: False when the collection is absent, True when retrieve hits.
    vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")])
    assert vector.text_exists("id-1") is False
    vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")])
    vector._client.retrieve.return_value = [{"id": "id-1"}]
    assert vector.text_exists("id-1") is True
|
||||
|
||||
|
||||
def test_search_and_helper_methods(qdrant_module):
    """Cover vector search filtering, full-text dedup, local reload and point-to-document mapping."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    # With an empty mocked search result, any threshold yields no docs.
    assert vector.search_by_vector([0.1], score_threshold=1.0) == []

    # Points without a payload must be dropped; scores land in metadata.
    vector._client.search.return_value = [
        SimpleNamespace(payload=None, score=0.9, vector=[0.1]),
        SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]),
    ]
    docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.8)

    # full text search: keyword split, dedup and top_k limit
    scroll_results = [
        (
            [
                SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]),
                SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]),
            ],
            None,
        ),
        (
            [
                SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]),
            ],
            None,
        ),
    ]
    vector._client.scroll.side_effect = scroll_results
    docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 2
    # A whitespace-only query produces no keywords and thus no results.
    assert vector.search_by_full_text(" ", top_k=2) == []

    # A QdrantLocal client must be reloaded before use.
    local_client = qdrant_module.QdrantLocal()
    vector._client = local_client
    vector._reload_if_needed()
    local_client._load.assert_called_once()

    doc = vector._document_from_scored_point(
        SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]),
        "page_content",
        "metadata",
    )
    assert doc.page_content == "doc"
|
||||
|
||||
|
||||
def test_qdrant_factory_paths(qdrant_module, monkeypatch):
    """Factory should generate a collection name, honor collection bindings, and fail on missing bindings."""
    factory = qdrant_module.QdrantVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        tenant_id="tenant-1",
        collection_binding_id=None,
        index_struct_dict=None,
        index_struct=None,
    )
    # Pin all config/state the factory reads so the test is hermetic.
    monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root")))
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333")
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key")
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_REPLICATION_FACTOR", 1)

    with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset.index_struct is not None

    # collection binding lookup path
    dataset.collection_binding_id = "binding-1"
    dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}}
    monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    qdrant_module.db.session.scalars = MagicMock(
        return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION"))
    )
    with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls:
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION"

    # A binding id with no matching row is a hard error.
    qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None))
    with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"):
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
|
||||
@ -0,0 +1,303 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.types import UserDefinedType
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_relyt_modules():
    """Stub out the ``pgvecto_rs`` package so the relyt module imports without the real driver."""
    package = types.ModuleType("pgvecto_rs")
    sqlalchemy_mod = types.ModuleType("pgvecto_rs.sqlalchemy")

    class VECTOR(UserDefinedType):
        # Only the dimension attribute is needed by the code under test.
        def __init__(self, dim):
            self.dim = dim

    sqlalchemy_mod.VECTOR = VECTOR
    modules = {
        "pgvecto_rs": package,
        "pgvecto_rs.sqlalchemy": sqlalchemy_mod,
    }
    return modules
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, execute_result=None):
|
||||
self.execute_result = execute_result or MagicMock(fetchall=lambda: [])
|
||||
self.execute = MagicMock(return_value=self.execute_result)
|
||||
self.commit = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
def relyt_module(monkeypatch):
    """Import the relyt vector module against stubbed pgvecto_rs modules."""
    for name, module in _build_fake_relyt_modules().items():
        monkeypatch.setitem(sys.modules, name, module)

    import core.rag.datasource.vdb.relyt.relyt_vector as module

    # Reload so the module re-binds to the stubbed pgvecto_rs even if it was
    # imported earlier in the test session with the real dependency.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module, **overrides):
|
||||
values = {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "postgres",
|
||||
"password": "secret",
|
||||
"database": "relyt",
|
||||
}
|
||||
values.update(overrides)
|
||||
return module.RelytConfig.model_validate(values)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config RELYT_HOST is required"),
        ("port", 0, "config RELYT_PORT is required"),
        ("user", "", "config RELYT_USER is required"),
        ("password", "", "config RELYT_PASSWORD is required"),
        ("database", "", "config RELYT_DATABASE is required"),
    ],
)
def test_relyt_config_validation(relyt_module, field, value, message):
    """Each required config field must raise a ValidationError when empty/zero."""
    # Start from a known-valid config and break exactly one field per case.
    values = _config(relyt_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        relyt_module.RelytConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
    """create() must size the collection from the embedding dim and delegate to add_texts."""
    engine = MagicMock()
    monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine))
    vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1")
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]

    vector.create(docs, [[0.1, 0.2]])

    assert vector.get_type() == relyt_module.VectorType.RELYT
    # Connection URL is assembled from the config fields.
    assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt"
    assert vector.embedding_dimension == 2
    vector.create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
    """create_collection should be a no-op on a redis cache hit and issue DDL on a miss."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock())

    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()

    # Cache hit: no SQL executed.
    monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
    session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    vector.create_collection(3)
    session.execute.assert_not_called()

    # Cache miss: DROP/CREATE TABLE/CREATE INDEX run and the flag is cached.
    monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
    session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    vector.create_collection(3)
    executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
    assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
    assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql)
    assert any("CREATE INDEX" in sql for sql in executed_sql)
    relyt_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch):
    """add_texts should insert rows (including group_id) and metadata lookups should map rows to ids."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector._group_id = "group-1"
    vector.client = MagicMock()

    # Wire up a fake SQLAlchemy connection whose `begin()` yields a context manager.
    begin_ctx = MagicMock()
    begin_ctx.__enter__.return_value = None
    begin_ctx.__exit__.return_value = None
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.begin.return_value = begin_ctx
    vector.client.connect.return_value = conn

    # Deterministic ids: uuid1 is consumed once per inserted document.
    monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"]))
    docs = [
        Document(page_content="a", metadata={"doc_id": "d-1"}),
        Document(page_content="b", metadata={"doc_id": "d-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])

    assert ids == ["id-1", "id-2"]
    assert conn.execute.call_count >= 1
    # Inspect the compiled INSERT to confirm the group_id column is populated.
    first_insert_values = conn.execute.call_args.args[0].compile().params
    assert "group_id" in str(first_insert_values)

    # Metadata lookup returns the matching row ids...
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)]))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"]

    # ...and None (not an empty list) when nothing matches.
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: []))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
|
||||
|
||||
|
||||
# 1. delete_by_uuids: success and connect error
|
||||
def test_delete_by_uuids_success_and_connect_error(relyt_module):
    """delete_by_uuids rejects missing ids, returns True on success and False on connection failure."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    # Passing no ids is a caller error, not a silent no-op.
    with pytest.raises(ValueError, match="No ids provided"):
        vector.delete_by_uuids(None)
    # Happy path: a working connection with a transactional `begin()` context.
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    begin_ctx = MagicMock()
    begin_ctx.__enter__.return_value = None
    begin_ctx.__exit__.return_value = None
    conn.begin.return_value = begin_ctx
    vector.client.connect.return_value = conn
    assert vector.delete_by_uuids(["id-1"]) is True
    # Failure path: connect() raising is swallowed and reported as False.
    vector.client.connect.side_effect = RuntimeError("boom")
    assert vector.delete_by_uuids(["id-1"]) is False
|
||||
|
||||
|
||||
# 2. delete_by_metadata_field calls delete_by_uuids
|
||||
def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module):
    """delete_by_metadata_field resolves matching uuids and delegates deletion to delete_by_uuids."""
    relyt = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt.client = MagicMock()
    relyt.embedding_dimension = 3
    relyt._collection_name = "collection_1"

    # Stub both collaborators: the lookup yields one uuid, the delete reports success.
    relyt.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    relyt.delete_by_uuids = MagicMock(return_value=True)

    relyt.delete_by_metadata_field("document_id", "doc-1")

    relyt.delete_by_uuids.assert_called_once_with(["id-1"])
|
||||
|
||||
|
||||
# 3. delete_by_ids translates to uuids
|
||||
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
    """delete_by_ids looks up the stored uuids for the given doc ids and deletes those uuids."""
    relyt = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt.client = MagicMock()
    relyt.embedding_dimension = 3
    relyt._collection_name = "collection_1"

    # The fake session resolves the two document ids to two stored uuids.
    uuid_rows = [("uuid-1",), ("uuid-2",)]
    fake_session = _FakeSession(execute_result=MagicMock(fetchall=lambda: uuid_rows))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: fake_session)
    relyt.delete_by_uuids = MagicMock(return_value=True)

    relyt.delete_by_ids(["doc-1", "doc-2"])

    relyt.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"])
|
||||
|
||||
|
||||
# 4. text_exists True
|
||||
def test_text_exists_true(relyt_module, monkeypatch):
    """text_exists reports True when the lookup query yields at least one row."""
    relyt = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt.client = MagicMock()
    relyt.embedding_dimension = 3
    relyt._collection_name = "collection_1"

    fake_session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-1",)]))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: fake_session)

    assert relyt.text_exists("doc-1") is True
|
||||
|
||||
|
||||
# 5. text_exists False
|
||||
def test_text_exists_false(relyt_module, monkeypatch):
    """text_exists reports False when the lookup query yields no rows."""
    relyt = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt.client = MagicMock()
    relyt.embedding_dimension = 3
    relyt._collection_name = "collection_1"

    fake_session = _FakeSession(execute_result=MagicMock(fetchall=lambda: []))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: fake_session)

    assert relyt.text_exists("doc-1") is False
|
||||
|
||||
|
||||
# 6. similarity_search_with_score_by_vector returns Documents and scores
|
||||
def test_similarity_search_with_score_by_vector(relyt_module):
    """Every fetched row is mapped, in order, to a (Document, distance) pair."""
    relyt = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt.client = MagicMock()
    relyt.embedding_dimension = 3
    relyt._collection_name = "collection_1"

    rows = [
        SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8),
    ]
    # Fake connection context manager returning the canned rows.
    connection = MagicMock()
    connection.__enter__.return_value = connection
    connection.__exit__.return_value = None
    connection.execute.return_value.fetchall.return_value = rows
    relyt.client.connect.return_value = connection

    results = relyt.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]})

    assert len(results) == 2
    assert results[0][0].page_content == "doc-a"
|
||||
|
||||
|
||||
# 7. search_by_vector filters by score and ids
|
||||
def test_search_by_vector_filters_by_score_and_ids(relyt_module):
    """search_by_vector applies the score threshold; full-text search always yields nothing."""
    relyt = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt.client = MagicMock()
    relyt.embedding_dimension = 3
    relyt._collection_name = "collection_1"

    scored_hits = [
        (Document(page_content="a", metadata={"doc_id": "1"}), 0.1),
        (Document(page_content="b", metadata={}), 0.9),
    ]
    relyt.similarity_search_with_score_by_vector = MagicMock(return_value=scored_hits)

    matches = relyt.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])

    # Only one of the two candidates survives the threshold filtering.
    assert len(matches) == 1
    # Relyt has no full-text index, so the full-text path is a no-op.
    assert relyt.search_by_full_text("query") == []
|
||||
|
||||
|
||||
# 8. delete commits session
|
||||
def test_delete_commits_session(relyt_module, monkeypatch):
    """delete() must commit the session it opens for the DROP statement."""
    relyt = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt.client = MagicMock()
    relyt.embedding_dimension = 3
    relyt._collection_name = "collection_1"

    fake_session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: fake_session)

    relyt.delete()

    fake_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
    """The factory reuses a persisted collection name and otherwise derives one from the dataset id."""
    factory = relyt_module.RelytVectorFactory()
    indexed_dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    fresh_dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Minimal connection settings so the factory can assemble its RelytConfig.
    for key, value in {
        "RELYT_HOST": "localhost",
        "RELYT_PORT": 5432,
        "RELYT_USER": "postgres",
        "RELYT_PASSWORD": "secret",
        "RELYT_DATABASE": "relyt",
    }.items():
        monkeypatch.setattr(relyt_module.dify_config, key, value)

    with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls:
        first = factory.init_vector(indexed_dataset, attributes=[], embeddings=MagicMock())
        second = factory.init_vector(fresh_dataset, attributes=[], embeddings=MagicMock())

    assert first == "vector"
    assert second == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # init_vector persists the generated index struct on the fresh dataset.
    assert fresh_dataset.index_struct is not None
|
||||
@ -0,0 +1,316 @@
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_tablestore_module():
    """Construct an in-memory stand-in for the `tablestore` SDK.

    Query/index constructors are replaced with tagged-tuple or SimpleNamespace
    factories so tests can assert on how the adapter builds requests, and the
    client is a bag of MagicMocks so every SDK call can be stubbed/inspected.
    """
    tablestore = types.ModuleType("tablestore")

    # Minimal batch-get request that just records the items added to it.
    class _BatchGetRowRequest:
        def __init__(self):
            self.items = []

        def add(self, item):
            self.items.append(item)

    # Mirrors the SDK constructor signature; the last two args are ignored here.
    class _TableInBatchGetRowItem:
        def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver):
            self.table_name = table_name
            self.rows_to_get = rows_to_get
            self.columns_to_get = columns_to_get

    class _Row:
        def __init__(self, primary_key, attribute_columns=None):
            self.primary_key = primary_key
            self.attribute_columns = attribute_columns or []

    # Fake OTS client: every API the adapter touches is a MagicMock with a
    # benign default return value.
    class _Client:
        def __init__(self, *_args):
            self.list_table = MagicMock(return_value=[])
            self.create_table = MagicMock()
            self.list_search_index = MagicMock(return_value=[])
            self.create_search_index = MagicMock()
            self.delete_search_index = MagicMock()
            self.delete_table = MagicMock()
            self.put_row = MagicMock()
            self.delete_row = MagicMock()
            self.get_row = MagicMock(return_value=(None, None, None))
            self.batch_get_row = MagicMock()
            self.search = MagicMock()

    tablestore.OTSClient = _Client
    tablestore.BatchGetRowRequest = _BatchGetRowRequest
    tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem
    tablestore.Row = _Row
    # Request-builder helpers return tagged tuples so assertions stay simple.
    tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema)
    tablestore.TableOptions = lambda: ("table_options",)
    tablestore.CapacityUnit = lambda read, write: ("capacity", read, write)
    tablestore.ReservedThroughput = lambda cap: ("reserved", cap)
    tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs)
    tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs)
    tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas)
    tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs)
    tablestore.TermQuery = lambda key, value: ("term_query", key, value)
    tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs)
    tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs)
    tablestore.TermsQuery = lambda key, values: ("terms_query", key, values)
    tablestore.Sort = lambda **kwargs: ("sort", kwargs)
    tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs)
    tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs)
    tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs)

    # Enum-like namespaces: plain strings are enough for equality checks.
    tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD")
    tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD")
    tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32")
    tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE")
    tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX")
    tablestore.SortOrder = SimpleNamespace(DESC="DESC")
    return tablestore
|
||||
|
||||
|
||||
@pytest.fixture
def tablestore_module(monkeypatch):
    """Reload the tablestore vector module against the stubbed `tablestore` SDK."""
    monkeypatch.setitem(sys.modules, "tablestore", _build_fake_tablestore_module())

    import core.rag.datasource.vdb.tablestore.tablestore_vector as module

    # Reload so the module-level SDK imports bind to the fake.
    module = importlib.reload(module)
    return module
|
||||
|
||||
|
||||
def _config(module, **overrides):
    """Build a valid TableStoreConfig, letting callers override individual fields."""
    defaults = {
        "access_key_id": "ak",
        "access_key_secret": "sk",
        "instance_name": "instance",
        "endpoint": "endpoint",
        "normalize_full_text_bm25_score": False,
    }
    return module.TableStoreConfig.model_validate({**defaults, **overrides})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("access_key_id", "", "config ACCESS_KEY_ID is required"),
        ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"),
        ("instance_name", "", "config INSTANCE_NAME is required"),
        ("endpoint", "", "config ENDPOINT is required"),
    ],
)
def test_tablestore_config_validation(tablestore_module, field, value, message):
    """Blanking any required connection field must fail validation with a clear message."""
    payload = _config(tablestore_module).model_dump()
    payload[field] = value

    with pytest.raises(ValidationError, match=message):
        tablestore_module.TableStoreConfig.model_validate(payload)
|
||||
|
||||
|
||||
def test_init_and_basic_delegation(tablestore_module):
    """Constructor derives table/index names; create() delegates to collection setup + add_texts."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    assert vector.get_type() == tablestore_module.VectorType.TABLESTORE
    assert vector._table_name == "collection_1"
    # The search index name is derived from the table name with an `_idx` suffix.
    assert vector._index_name == "collection_1_idx"

    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "d-1"})]
    vector.create(docs, [[0.1, 0.2]])
    # The embedding dimension (2) is taken from the first embedding vector.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(documents=docs, embeddings=[[0.1, 0.2]])

    # create_collection is a thin public wrapper over _create_collection.
    vector.create_collection([[0.1, 0.2]])
    assert vector._create_collection.call_count == 2
|
||||
|
||||
|
||||
def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module):
    """Exercise get_by_ids, text_exists, the delete wrappers and the search delegation paths."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))

    # get_by_ids: only rows whose batch item is_ok are turned into Documents.
    ok_item = SimpleNamespace(
        is_ok=True,
        row=SimpleNamespace(
            attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)]
        ),
    )
    fail_item = SimpleNamespace(is_ok=False, row=None)
    batch_resp = SimpleNamespace(get_result_by_table=lambda _table: [ok_item, fail_item])
    vector._tablestore_client.batch_get_row.return_value = batch_resp
    docs = vector.get_by_ids(["id-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "text-1"

    # text_exists: presence is signalled by a non-None row in the get_row result tuple.
    vector._tablestore_client.get_row.return_value = (None, object(), None)
    assert vector.text_exists("id-1") is True
    vector._tablestore_client.get_row.return_value = (None, None, None)
    assert vector.text_exists("id-1") is False

    # delete wrappers: empty input is a no-op, otherwise one row delete per id.
    vector._delete_row = MagicMock()
    vector.delete_by_ids([])
    vector._delete_row.assert_not_called()
    vector.delete_by_ids(["id-1", "id-2"])
    assert vector._delete_row.call_count == 2

    # Metadata-based lookups/deletes delegate through _search_by_metadata.
    vector._search_by_metadata = MagicMock(return_value=["id-a"])
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"]
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-a"])

    # Public search entry points just forward to the private implementations.
    vector._search_by_vector = MagicMock(return_value=["vec-doc"])
    vector._search_by_full_text = MagicMock(return_value=["fts-doc"])
    assert vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"]
    assert vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == ["fts-doc"]

    # delete() tears down the backing table (and its indexes).
    vector._delete_table_if_exist = MagicMock()
    vector.delete()
    vector._delete_table_if_exist.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch):
    """_create_collection honours the redis flag; table/index helpers are idempotent."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    # Fake the distributed lock guarding collection creation.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock())

    # Cache hit: redis flag set -> table creation is skipped entirely.
    monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_table_if_not_exist = MagicMock()
    vector._create_search_index_if_not_exist = MagicMock()
    vector._create_collection(3)
    vector._create_table_if_not_exist.assert_not_called()

    # Cache miss: table and search index are created, flag is written once.
    monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(3)
    vector._create_table_if_not_exist.assert_called_once()
    vector._create_search_index_if_not_exist.assert_called_once_with(3)
    tablestore_module.redis_client.set.assert_called_once()

    # Fresh instance to exercise the real (unstubbed) table/index helpers.
    vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module))
    # Table already listed -> early return, no create_table call.
    vector._tablestore_client.list_table.return_value = ["collection_2"]
    assert vector._create_table_if_not_exist() is None
    vector._tablestore_client.list_table.return_value = []
    vector._create_table_if_not_exist()
    vector._tablestore_client.create_table.assert_called_once()

    # Search index already listed -> early return, no create_search_index call.
    vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")]
    assert vector._create_search_index_if_not_exist(3) is None
    vector._tablestore_client.list_search_index.return_value = []
    vector._create_search_index_if_not_exist(3)
    vector._tablestore_client.create_search_index.assert_called_once()

    # Table deletion removes every attached search index first, then the table.
    vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")]
    vector._delete_table_if_exist()
    assert vector._tablestore_client.delete_search_index.call_count == 2
    vector._tablestore_client.delete_table.assert_called_once_with("collection_2")

    # Direct index deletion targets the instance's own derived index name.
    vector._delete_search_index()
    vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx")
|
||||
|
||||
|
||||
def test_write_row_and_search_helpers(tablestore_module):
    """Exercise row writing/deletion, paginated metadata search, and vector/full-text search filtering."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))

    # _write_row should target the table and expand metadata into a metadata_tags column.
    vector._write_row(
        "id-1",
        {
            "page_content": "hello",
            "vector": [0.1, 0.2],
            "metadata": {"doc_id": "d-1", "document_id": "doc-1"},
        },
    )
    put_row_call = vector._tablestore_client.put_row.call_args
    assert put_row_call.args[0] == "collection_1"
    attrs = put_row_call.args[1].attribute_columns
    assert any(item[0] == "metadata_tags" for item in attrs)

    vector._delete_row("id-1")
    vector._tablestore_client.delete_row.assert_called_once()

    # metadata search pagination: follow next_token until it is empty,
    # collecting the primary-key ids from each page.
    first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next")
    second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"")
    vector._tablestore_client.search.side_effect = [first_page, second_page]
    ids = vector._search_by_metadata("document_id", "doc-1")
    assert ids == ["row-1", "row-2"]
    vector._tablestore_client.search.side_effect = None

    # vector search: hits below score_threshold are dropped, the survivor
    # carries its score in metadata.
    hit1 = SimpleNamespace(
        score=0.9,
        row=(
            None,
            [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))],
        ),
    )
    hit2 = SimpleNamespace(
        score=0.2,
        row=(
            None,
            [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))],
        ),
    )
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2])
    docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)

    # The exponential-decay normalization maps 0 -> 0.0 and stays bounded by 1.
    assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0)
    assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0

    # full text search with and without normalized score filter
    vector._normalize_full_text_bm25_score = True
    hit3 = SimpleNamespace(
        score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))])
    )
    hit4 = SimpleNamespace(
        score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))])
    )
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4])
    # Normalization enabled: the low-score hit falls under the threshold and is dropped.
    docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2)
    assert len(docs) == 1
    assert "score" in docs[0].metadata

    # Normalization disabled: raw BM25 scores are not attached to the documents.
    vector._normalize_full_text_bm25_score = False
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3])
    docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0)
    assert len(docs) == 1
    assert "score" not in docs[0].metadata
|
||||
|
||||
|
||||
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch):
    """The factory honours a persisted collection name and otherwise derives one from the dataset id."""
    factory = tablestore_module.TableStoreVectorFactory()
    indexed_dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    fresh_dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Minimal connection settings so the factory can assemble its TableStoreConfig.
    for key, value in {
        "TABLESTORE_ENDPOINT": "endpoint",
        "TABLESTORE_INSTANCE_NAME": "instance",
        "TABLESTORE_ACCESS_KEY_ID": "ak",
        "TABLESTORE_ACCESS_KEY_SECRET": "sk",
        "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE": True,
    }.items():
        monkeypatch.setattr(tablestore_module.dify_config, key, value)

    with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls:
        first = factory.init_vector(indexed_dataset, attributes=[], embeddings=MagicMock())
        second = factory.init_vector(fresh_dataset, attributes=[], embeddings=MagicMock())

    assert first == "vector"
    assert second == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # init_vector persists the generated index struct on the fresh dataset.
    assert fresh_dataset.index_struct is not None
|
||||
@ -0,0 +1,309 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_tencent_modules():
    """Construct in-memory stand-ins for the `tcvectordb` and `tcvdb_text` SDK modules.

    Returns a mapping from dotted module name to fake module, ready to be
    injected into sys.modules before the tencent vector adapter is imported.
    """
    tcvdb_text = types.ModuleType("tcvdb_text")
    tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder")
    tcvectordb = types.ModuleType("tcvectordb")
    tcvectordb_model = types.ModuleType("tcvectordb.model")
    tcvectordb_document = types.ModuleType("tcvectordb.model.document")
    tcvectordb_index = types.ModuleType("tcvectordb.model.index")
    tcvectordb_enum = types.ModuleType("tcvectordb.model.enum")

    # Deterministic sparse-vector encoder stub: echoes its input in a dict.
    class _BM25Encoder:
        def encode_texts(self, text):
            return {"encoded_text": text}

        def encode_queries(self, query):
            return {"encoded_query": query}

        @classmethod
        def default(cls, _lang):
            return cls()

    # Stands in for tcvectordb's VectorDBException; keeps `.message` because
    # the adapter inspects it to detect unsupported field types.
    class VectorDBError(Exception):
        def __init__(self, message):
            super().__init__(message)
            self.message = message

    # Fake RPC client: records constructor kwargs and exposes every API the
    # adapter touches as a MagicMock with a benign default.
    class RPCVectorDBClient:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.create_database_if_not_exists = MagicMock()
            self.exists_collection = MagicMock(return_value=False)
            self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[]))
            self.create_collection = MagicMock()
            self.upsert = MagicMock()
            self.query = MagicMock(return_value=[])
            self.delete = MagicMock()
            self.search = MagicMock(return_value=[])
            self.hybrid_search = MagicMock(return_value=[])
            self.drop_collection = MagicMock()

    # Document stub: accepts arbitrary fields so tests can inspect __dict__.
    class _Document:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    class _HNSWSearchParams:
        def __init__(self, ef):
            self.ef = ef

    class _AnnSearch:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _KeywordSearch:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _WeightedRerank:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _Filter:
        @staticmethod
        def in_(field, values):
            return ("in", field, values)

        def __init__(self, condition):
            self.condition = condition

    # The real SDK exposes `Filter.In`; alias it to the snake_case helper.
    _Filter.In = staticmethod(_Filter.in_)

    class _HNSWParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _FilterIndex:
        def __init__(self, *args):
            self.args = args

    class _VectorIndex:
        def __init__(self, *args):
            self.args = args

    class _SparseIndex:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    # Enum stand-ins expose __members__ because the adapter looks names up there.
    tcvectordb_enum.IndexType = SimpleNamespace(
        __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"},
        PRIMARY_KEY="PRIMARY_KEY",
        FILTER="FILTER",
        SPARSE_INVERTED="SPARSE",
    )
    tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP")
    tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector")

    tcvectordb_document.Document = _Document
    tcvectordb_document.HNSWSearchParams = _HNSWSearchParams
    tcvectordb_document.AnnSearch = _AnnSearch
    tcvectordb_document.Filter = _Filter
    tcvectordb_document.KeywordSearch = _KeywordSearch
    tcvectordb_document.WeightedRerank = _WeightedRerank

    tcvectordb_index.HNSWParams = _HNSWParams
    tcvectordb_index.FilterIndex = _FilterIndex
    tcvectordb_index.VectorIndex = _VectorIndex
    tcvectordb_index.SparseIndex = _SparseIndex

    tcvdb_text_encoder.BM25Encoder = _BM25Encoder

    # Wire the submodules together the way the real package lays them out.
    tcvectordb_model.document = tcvectordb_document
    tcvectordb_model.enum = tcvectordb_enum
    tcvectordb_model.index = tcvectordb_index

    tcvectordb.RPCVectorDBClient = RPCVectorDBClient
    tcvectordb.VectorDBException = VectorDBError

    return {
        "tcvdb_text": tcvdb_text,
        "tcvdb_text.encoder": tcvdb_text_encoder,
        "tcvectordb": tcvectordb,
        "tcvectordb.model": tcvectordb_model,
        "tcvectordb.model.document": tcvectordb_document,
        "tcvectordb.model.index": tcvectordb_index,
        "tcvectordb.model.enum": tcvectordb_enum,
    }
|
||||
|
||||
|
||||
@pytest.fixture
def tencent_module(monkeypatch):
    """Reload the tencent vector module against the stubbed tcvectordb/tcvdb_text SDKs."""
    fakes = _build_fake_tencent_modules()
    for module_name in fakes:
        monkeypatch.setitem(sys.modules, module_name, fakes[module_name])

    import core.rag.datasource.vdb.tencent.tencent_vector as module

    # Reload so the module-level SDK imports bind to the fakes.
    return importlib.reload(module)
|
||||
|
||||
|
||||
def _config(module, **overrides):
    """Build a valid TencentConfig, letting callers override individual fields."""
    defaults = {
        "url": "http://vdb.local",
        "api_key": "api-key",
        "timeout": 30,
        "username": "user",
        "database": "db",
        "index_type": "HNSW",
        "metric_type": "IP",
        "shard": 1,
        "replicas": 2,
        "max_upsert_batch_size": 2,
        "enable_hybrid_search": False,
    }
    return module.TencentConfig.model_validate({**defaults, **overrides})
|
||||
|
||||
|
||||
def test_config_and_init_paths(tencent_module):
    """Config maps into client params; _load_collection detects hybrid search and dimension."""
    config = _config(tencent_module)
    assert config.to_tencent_params()["url"] == "http://vdb.local"

    vector = tencent_module.TencentVector("collection_1", config)
    assert vector.get_type() == tencent_module.VectorType.TENCENT
    # The api_key is passed to the RPC client under the `key` kwarg.
    assert vector._client.kwargs["key"] == "api-key"

    # An existing collection with a sparse_vector index enables hybrid search
    # and picks up the dense vector dimension.
    vector._client.exists_collection.return_value = True
    vector._client.describe_collection.return_value = SimpleNamespace(
        indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)]
    )
    vector._client_config.enable_hybrid_search = True
    vector._load_collection()
    assert vector._enable_hybrid_search is True
    assert vector._dimension == 768

    # Without a sparse_vector index, hybrid search stays disabled.
    vector._client.describe_collection.return_value = SimpleNamespace(
        indexes=[SimpleNamespace(name="vector", dimension=512)]
    )
    vector._load_collection()
    assert vector._enable_hybrid_search is False
|
||||
|
||||
|
||||
def test_create_collection_branches(tencent_module, monkeypatch):
    """_create_collection: cache/exists short-circuits, config validation, and json-field fallback."""
    vector = tencent_module.TencentVector("collection_1", _config(tencent_module))

    # Fake the distributed lock guarding collection creation.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock())

    # Cache hit: redis flag present -> nothing is created.
    monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    vector._client.create_collection.assert_not_called()

    # Collection already exists server-side -> creation is skipped too.
    monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.exists_collection.return_value = True
    vector._create_collection(3)
    vector._client.create_collection.assert_not_called()

    # Invalid index/metric types in the config must be rejected up front.
    vector._client.exists_collection.return_value = False
    vector._client_config.index_type = "UNKNOWN"
    with pytest.raises(ValueError, match="unsupported index_type"):
        vector._create_collection(3)

    vector._client_config.index_type = "HNSW"
    vector._client_config.metric_type = "UNKNOWN"
    with pytest.raises(ValueError, match="unsupported metric_type"):
        vector._create_collection(3)

    # Server rejecting the json field type triggers a second create attempt
    # (fallback schema); the redis flag is written once on success.
    vector._client_config.metric_type = "IP"
    vector._client.create_collection.side_effect = [
        tencent_module.VectorDBException("fieldType:json unsupported"),
        None,
    ]
    vector._enable_hybrid_search = True
    vector._create_collection(3)
    assert vector._client.create_collection.call_count == 2
    tencent_module.redis_client.set.assert_called_once()
    vector._client.create_collection.side_effect = None
|
||||
|
||||
|
||||
def test_create_add_delete_and_search_behaviour(tencent_module):
    """End-to-end behaviour of create/add/exists/delete/search with hybrid search enabled."""
    vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True))
    vector._create_collection = MagicMock()
    docs = [
        Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}),
        Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}),
        Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}),
    ]
    embeddings = [[0.1], [0.2], [0.3]]
    # create() infers the dimension (1) from the embeddings.
    vector.create(docs, embeddings)
    vector._create_collection.assert_called_once_with(1)

    # add_texts splits the three docs into two upserts; hybrid search attaches
    # a sparse_vector attribute to each stored document.
    vector._client.upsert.reset_mock()
    vector.add_texts(docs, embeddings)
    assert vector._client.upsert.call_count == 2
    first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"]
    assert "sparse_vector" in first_docs[0].__dict__

    vector._client.query.return_value = [{"id": "a"}]
    assert vector.text_exists("a") is True
    vector._client.query.return_value = []
    assert vector.text_exists("a") is False

    # delete_by_ids is a no-op for an empty list and batches larger lists.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector.delete_by_ids(["a", "b", "c"])
    assert vector._client.delete.call_count == 2
    vector.delete_by_metadata_field("document_id", "doc-a")
    assert vector._client.delete.call_count >= 3

    # Vector search keeps hits above the threshold and copies the score into metadata.
    vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]]
    vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"])
    assert len(vec_docs) == 1
    assert vec_docs[0].metadata["score"] == pytest.approx(0.9)

    # Full-text search is only available when hybrid search is enabled.
    vector._enable_hybrid_search = False
    assert vector.search_by_full_text("query") == []
    vector._enable_hybrid_search = True
    vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]]
    fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"])
    assert len(fts_docs) == 1

    # _get_search_res handles old string metadata format
    # NOTE(review): asserting 0.8 from a raw score of 0.2 implies _get_search_res
    # converts distance to similarity (1 - score) — confirm against the implementation.
    compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.5)
    assert len(compat_docs) == 1
    assert compat_docs[0].metadata["score"] == pytest.approx(0.8)

    vector._has_collection = MagicMock(return_value=True)
    vector.delete()
    vector._client.drop_collection.assert_called_once()
|
||||
|
||||
|
||||
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch):
    """Factory reuses a stored collection name or generates one; names are lowercased."""
    factory = tencent_module.TencentVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    # Pin every config value the factory reads so the test is hermetic.
    monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_URL", "http://vdb.local")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_API_KEY", "api-key")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_TIMEOUT", 30)
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_USERNAME", "user")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_DATABASE", "db")
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_REPLICAS", 2)
    monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH", True)

    with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # The factory lowercases both the stored and the generated collection names.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset lacking an index struct gets one persisted onto it.
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,88 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class _DummyVector(BaseVector):
    """Minimal concrete BaseVector used to exercise the base-class helper methods."""

    def __init__(self, collection_name: str, existing_ids: set[str] | None = None):
        super().__init__(collection_name)
        # Doc IDs that text_exists() should report as already stored.
        self._existing_ids = existing_ids or set()

    def get_type(self) -> str:
        return "dummy"

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        return None

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        return None

    def text_exists(self, id: str) -> bool:
        # Membership in the preconfigured set drives duplicate filtering in tests.
        return id in self._existing_ids

    def delete_by_ids(self, ids: list[str]):
        return None

    def delete_by_metadata_field(self, key: str, value: str):
        return None

    def search_by_vector(self, query_vector: list[float], **kwargs):
        return []

    def search_by_full_text(self, query: str, **kwargs):
        return []

    def delete(self):
        return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("base_method", "args"),
    [
        (BaseVector.get_type, ()),
        (BaseVector.create, ([], [])),
        (BaseVector.add_texts, ([], [])),
        (BaseVector.text_exists, ("doc-1",)),
        (BaseVector.delete_by_ids, ([],)),
        (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")),
        (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")),
        (BaseVector.search_by_vector, ([0.1],)),
        (BaseVector.search_by_full_text, ("query",)),
        (BaseVector.delete, ()),
    ],
)
def test_base_vector_default_methods_raise_not_implemented(base_method, args):
    """Every BaseVector default implementation raises NotImplementedError."""
    vector = _DummyVector("collection_1")

    # Invoke the *base-class* function directly so the subclass overrides
    # cannot shadow the NotImplementedError defaults.
    with pytest.raises(NotImplementedError):
        base_method(vector, *args)
|
||||
|
||||
|
||||
def test_filter_duplicate_texts_removes_existing_docs():
    """_filter_duplicate_texts drops only docs whose doc_id already exists."""
    vector = _DummyVector("collection_1", existing_ids={"dup"})
    docs = [
        # No metadata at all: kept (nothing to deduplicate on).
        SimpleNamespace(page_content="keep-no-meta", metadata=None),
        # Metadata without a doc_id key: also kept.
        Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}),
        # doc_id already stored: removed.
        Document(page_content="remove-dup", metadata={"doc_id": "dup"}),
        Document(page_content="keep-unique", metadata={"doc_id": "unique"}),
    ]

    filtered = vector._filter_duplicate_texts(docs)

    assert [d.page_content for d in filtered] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"]
|
||||
|
||||
|
||||
def test_get_uuids_and_collection_name_property():
    """_get_uuids collects only docs carrying a doc_id; property echoes the name."""
    vector = _DummyVector("collection_1")
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        # No metadata / no doc_id entries are skipped by _get_uuids.
        SimpleNamespace(page_content="b", metadata=None),
        Document(page_content="c", metadata={"document_id": "d-1"}),
        Document(page_content="d", metadata={"doc_id": "id-2"}),
    ]

    assert vector._get_uuids(docs) == ["id-1", "id-2"]
    assert vector.collection_name == "collection_1"
|
||||
@ -0,0 +1,434 @@
|
||||
import base64
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str):
    """Install a stub module exposing an empty factory class under *module_path*.

    Returns the stub class so callers can assert identity against whatever the
    code under test resolves via ``sys.modules``.
    """
    stub_cls = type(class_name, (), {})
    stub_module = types.ModuleType(module_path)
    setattr(stub_module, class_name, stub_cls)
    monkeypatch.setitem(sys.modules, module_path, stub_module)
    return stub_cls
|
||||
|
||||
|
||||
@pytest.fixture
def vector_factory_module():
    """Reload vector_factory so module-level state is fresh for every test."""
    from importlib import reload

    import core.rag.datasource.vdb.vector_factory as module

    return reload(module)
|
||||
|
||||
|
||||
def test_gen_index_struct_dict(vector_factory_module):
    """gen_index_struct_dict nests the collection name under vector_store."""
    index_struct = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict(
        vector_factory_module.VectorType.WEAVIATE,
        "collection_1",
    )

    expected = {
        "type": vector_factory_module.VectorType.WEAVIATE,
        "vector_store": {"class_prefix": "collection_1"},
    }
    assert index_struct == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("vector_type", "module_path", "class_name"),
    [
        ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"),
        ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"),
        (
            "ALIBABACLOUD_MYSQL",
            "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector",
            "AlibabaCloudMySQLVectorFactory",
        ),
        ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"),
        ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"),
        ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"),
        ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"),
        ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"),
        ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"),
        (
            "ELASTICSEARCH",
            "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector",
            "ElasticSearchVectorFactory",
        ),
        (
            "ELASTICSEARCH_JA",
            "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector",
            "ElasticSearchJaVectorFactory",
        ),
        ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"),
        ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"),
        ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"),
        ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"),
        (
            "OPENSEARCH",
            "core.rag.datasource.vdb.opensearch.opensearch_vector",
            "OpenSearchVectorFactory",
        ),
        ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"),
        ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"),
        ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"),
        ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"),
        ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"),
        (
            "TIDB_ON_QDRANT",
            "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector",
            "TidbOnQdrantVectorFactory",
        ),
        ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"),
        ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"),
        # SEEKDB deliberately maps to the same OceanBase factory.
        ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"),
        ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"),
        ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"),
        (
            "HUAWEI_CLOUD",
            "core.rag.datasource.vdb.huawei.huawei_cloud_vector",
            "HuaweiCloudVectorFactory",
        ),
        ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"),
        ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"),
        ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"),
    ],
)
def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name):
    """get_vector_factory resolves the right factory class for every supported type."""
    # Register a stub module so the real (possibly uninstalled) backend driver
    # is never imported.
    expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name)

    result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type))

    assert result_cls is expected_cls
|
||||
|
||||
|
||||
def test_get_vector_factory_unsupported(vector_factory_module):
    """An unknown vector type is rejected with a descriptive ValueError."""
    unknown_type = "unknown"
    with pytest.raises(ValueError, match="not supported"):
        vector_factory_module.Vector.get_vector_factory(unknown_type)
|
||||
|
||||
|
||||
def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
    """Vector() falls back to the default attribute list when none is supplied."""
    dataset = SimpleNamespace(id="dataset-1")

    # Patch the heavy init helpers so only attribute handling is exercised.
    with (
        patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"),
        patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"),
    ):
        default_vector = vector_factory_module.Vector(dataset)
        custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])

    assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
    assert custom_vector._attributes == ["doc_id"]
    assert default_vector._embeddings == "embeddings"
    assert default_vector._vector_processor == "processor"
|
||||
|
||||
|
||||
def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch):
    """_init_vector resolves the factory from the dataset's stored index type."""
    calls = {"vector_type": None, "init_args": None}

    class _Factory:
        # Stub factory recording what it is asked to initialize.
        def init_vector(self, dataset, attributes, embeddings):
            calls["init_args"] = (dataset, attributes, embeddings)
            return "vector-processor"

    # Record which vector type is requested while returning the stub factory.
    monkeypatch.setattr(
        vector_factory_module.Vector,
        "get_vector_factory",
        staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory),
    )

    # Bypass __init__ so only _init_vector is exercised.
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(
        index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1"
    )
    vector._attributes = ["doc_id"]
    vector._embeddings = "embeddings"

    result = vector._init_vector()

    assert result == "vector-processor"
    assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH
    assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings")
|
||||
|
||||
|
||||
def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch):
    """A whitelisted tenant is routed to TIDB_ON_QDRANT even when config says CHROMA."""

    class _Expr:
        # Stand-in for a SQLAlchemy column: equality yields a fake clause.
        def __eq__(self, _other):
            return "expr"

    calls = {"vector_type": None}

    class _Factory:
        def init_vector(self, dataset, attributes, embeddings):
            return "vector-processor"

    monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr()))
    monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    # A non-None whitelist row marks the tenant as whitelisted.
    monkeypatch.setattr(
        vector_factory_module,
        "db",
        SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: object()))),
    )
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA)
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True)
    # Record which vector type was ultimately requested.
    monkeypatch.setattr(
        vector_factory_module.Vector,
        "get_vector_factory",
        staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory),
    )

    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1")
    vector._attributes = ["doc_id"]
    vector._embeddings = "embeddings"

    result = vector._init_vector()

    assert result == "vector-processor"
    assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT
|
||||
|
||||
|
||||
def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch):
    """_init_vector fails clearly when neither dataset nor config name a store."""
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", None)
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", False)

    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1")
    vector._attributes = []
    vector._embeddings = "embeddings"

    with pytest.raises(ValueError, match="Vector store must be specified"):
        vector._init_vector()
|
||||
|
||||
|
||||
def test_create_batches_texts_and_skips_empty_input(vector_factory_module):
    """create() embeds/stores in batches of 1000 and treats empty input as a no-op."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()

    # 1001 docs force two batches: 1000 + 1.
    docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)]
    vector._embeddings.embed_documents.side_effect = [
        [[0.1] for _ in range(1000)],
        [[0.2]],
    ]

    vector.create(texts=docs, trace_id="trace-1")

    assert vector._embeddings.embed_documents.call_count == 2
    assert vector._vector_processor.create.call_count == 2
    # Extra kwargs such as trace_id are forwarded to the processor.
    assert vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1"

    # texts=None skips both embedding and storage entirely.
    vector._embeddings.embed_documents.reset_mock()
    vector._vector_processor.create.reset_mock()
    vector.create(texts=None)
    vector._embeddings.embed_documents.assert_not_called()
    vector._vector_processor.create.assert_not_called()
|
||||
|
||||
|
||||
def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch):
    """create_multimodal embeds only docs whose upload file exists in the DB."""

    class _Field:
        # Mimic a SQLAlchemy column: both in_() and == produce a fake clause.
        def in_(self, value):
            return value

        def __eq__(self, value):
            return value

    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]]
    vector._vector_processor = MagicMock()

    monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
    monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    # Only upload "f-1" exists in the DB; "f-2" should be silently dropped.
    monkeypatch.setattr(
        vector_factory_module,
        "db",
        SimpleNamespace(
            session=SimpleNamespace(
                scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")])
            )
        ),
    )
    monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc"))

    docs = [
        Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}),
        Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}),
    ]

    vector.create_multimodal(file_documents=docs, request_id="r-1")

    # File content is sent base64-encoded; only the existing upload is embedded.
    file_base64 = base64.b64encode(b"abc").decode()
    vector._embeddings.embed_multimodal_documents.assert_called_once_with(
        [{"content": file_base64, "content_type": "image", "file_id": "f-1"}]
    )
    vector._vector_processor.create.assert_called_once_with(
        texts=[docs[0]],
        embeddings=[[0.1, 0.2]],
        request_id="r-1",
    )

    # file_documents=None is a no-op.
    vector._embeddings.embed_multimodal_documents.reset_mock()
    vector._vector_processor.create.reset_mock()
    vector.create_multimodal(file_documents=None)
    vector._embeddings.embed_multimodal_documents.assert_not_called()
    vector._vector_processor.create.assert_not_called()
|
||||
|
||||
|
||||
def test_add_texts_with_optional_duplicate_check(vector_factory_module):
    """add_texts filters duplicates only when duplicate_check=True."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()
    vector._filter_duplicate_texts = MagicMock()

    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    # Pretend only the first doc survives the duplicate filter.
    vector._filter_duplicate_texts.return_value = [docs[0]]
    vector._embeddings.embed_documents.return_value = [[0.1]]

    vector.add_texts(docs, duplicate_check=True, flag=True)

    # Only the filtered docs are embedded/stored; extra kwargs are forwarded.
    vector._filter_duplicate_texts.assert_called_once_with(docs)
    vector._vector_processor.create.assert_called_once_with(
        texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True
    )

    vector._filter_duplicate_texts.reset_mock()
    vector._vector_processor.create.reset_mock()
    vector._embeddings.embed_documents.return_value = [[0.2], [0.3]]

    # With duplicate_check=False the filter is bypassed entirely.
    vector.add_texts(docs, duplicate_check=False)

    vector._filter_duplicate_texts.assert_not_called()
    vector._vector_processor.create.assert_called_once()
|
||||
|
||||
|
||||
def test_vector_delegation_methods(vector_factory_module):
    """Vector delegates exists/delete/search calls to its underlying processor."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    # Text queries are embedded before being delegated to search_by_vector.
    vector._embeddings.embed_query.return_value = [0.1, 0.2]
    vector._vector_processor = MagicMock()
    vector._vector_processor.text_exists.return_value = True
    vector._vector_processor.search_by_vector.return_value = ["vector-doc"]
    vector._vector_processor.search_by_full_text.return_value = ["text-doc"]

    assert vector.text_exists("doc-1") is True
    vector.delete_by_ids(["doc-1"])
    vector.delete_by_metadata_field("doc_id", "doc-1")
    assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"]
    assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"]

    vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"])
    vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1")
|
||||
|
||||
|
||||
def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
    """search_by_file returns [] for an unknown upload and searches otherwise."""

    class _Field:
        # Fake SQLAlchemy column whose == yields a fake filter clause.
        def __eq__(self, value):
            return value

    upload_query = MagicMock()
    upload_query.where.return_value = upload_query

    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()

    monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
    monkeypatch.setattr(
        vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
    )

    # Unknown file id: nothing to embed, empty result.
    upload_query.first.return_value = None
    assert vector.search_by_file("file-1") == []

    # Known file: load the blob, embed it as an image query, delegate the search.
    upload_query.first.return_value = SimpleNamespace(key="blob-key")
    monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
    vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
    vector._vector_processor.search_by_vector.return_value = ["hit"]

    result = vector.search_by_file("file-2", top_k=2)

    assert result == ["hit"]
    payload = vector._embeddings.embed_multimodal_query.call_args.args[0]
    assert payload["content_type"] == vector_factory_module.DocType.IMAGE
    assert payload["file_id"] == "file-2"
|
||||
|
||||
|
||||
def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch):
    """delete() clears the vector-indexing redis key only for a named collection."""
    delete_mock = MagicMock()
    redis_delete = MagicMock()
    monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete)

    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1")

    vector.delete()

    delete_mock.assert_called_once()
    redis_delete.assert_called_once_with("vector_indexing_collection_1")

    # An empty collection name skips the cache invalidation.
    vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="")
    redis_delete.reset_mock()
    vector.delete()
    redis_delete.assert_not_called()
|
||||
|
||||
|
||||
def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch):
    """_get_embeddings wraps the dataset's embedding model in a CacheEmbedding."""
    model_manager = MagicMock()
    model_manager.get_model_instance.return_value = "model-instance"

    monkeypatch.setattr(vector_factory_module, "ModelManager", MagicMock(return_value=model_manager))
    monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding"))

    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(
        tenant_id="tenant-1",
        embedding_model_provider="openai",
        embedding_model="text-embedding-3-small",
    )

    result = vector._get_embeddings()

    assert result == "cached-embedding"
    # The model instance is resolved from the dataset's tenant/provider/model.
    model_manager.get_model_instance.assert_called_once_with(
        tenant_id="tenant-1",
        provider="openai",
        model_type=vector_factory_module.ModelType.TEXT_EMBEDDING,
        model="text-embedding-3-small",
    )
|
||||
|
||||
|
||||
def test_filter_duplicate_texts_and_getattr(vector_factory_module):
    """_filter_duplicate_texts dedupes by doc_id; __getattr__ proxies the processor."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    # Only "dup" counts as already stored.
    vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup")

    docs = [
        SimpleNamespace(page_content="no-meta", metadata=None),
        Document(page_content="empty-doc-id", metadata={"doc_id": ""}),
        Document(page_content="duplicate", metadata={"doc_id": "dup"}),
        Document(page_content="unique", metadata={"doc_id": "ok"}),
    ]

    # Docs without metadata or with an empty doc_id are kept without a lookup.
    filtered = vector._filter_duplicate_texts(docs)
    assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"]

    class _Processor:
        def ping(self):
            return "pong"

    # Unknown attributes are proxied to the underlying processor.
    vector._vector_processor = _Processor()
    assert vector.ping() == "pong"

    with pytest.raises(AttributeError):
        _ = vector.unknown_method

    # Without a processor the proxy raises, mentioning vector_processor.
    vector._vector_processor = None
    with pytest.raises(AttributeError, match="vector_processor"):
        _ = vector.another_missing
|
||||
@ -0,0 +1,443 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
@pytest.fixture
def tidb_module():
    """Reload the TiDB vector module so each test sees pristine module state."""
    from importlib import reload

    import core.rag.datasource.vdb.tidb_vector.tidb_vector as module

    return reload(module)
|
||||
|
||||
|
||||
def _config(tidb_module):
    """Build a minimal valid TiDBVectorConfig for the reloaded module."""
    params = {
        "host": "localhost",
        "port": 4000,
        "user": "root",
        "password": "secret",
        "database": "dify",
        "program_name": "dify-app",
    }
    return tidb_module.TiDBVectorConfig(**params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config TIDB_VECTOR_HOST is required"),
        ("port", 0, "config TIDB_VECTOR_PORT is required"),
        ("user", "", "config TIDB_VECTOR_USER is required"),
        ("database", "", "config TIDB_VECTOR_DATABASE is required"),
        ("program_name", "", "config APPLICATION_NAME is required"),
    ],
)
def test_tidb_config_validation(tidb_module, field, value, message):
    """Each required TiDB config field rejects an empty/zero value."""
    # Start from a valid config and blank out one field at a time.
    values = _config(tidb_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        tidb_module.TiDBVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
    """Constructor builds the SQLAlchemy URL; distance names map with a cosine default."""
    monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine"))

    vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2")

    assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR
    assert vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify")
    # The default dimension before any create() call is 1536.
    assert vector._dimension == 1536
    # Distance-name mapping is case-insensitive for "L2".
    assert vector._get_distance_func() == "VEC_L2_DISTANCE"

    vector._distance_func = "cosine"
    assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"

    # Unknown names fall back to cosine distance.
    vector._distance_func = "other"
    assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
|
||||
|
||||
|
||||
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch):
    """_table defines primary-key, vector, and text columns on the collection table."""
    # Provide a stub tidb_vector.sqlalchemy so the real driver is not required.
    fake_tidb_vector = types.ModuleType("tidb_vector")
    fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy")

    class _VectorType:
        def __init__(self, dim):
            self.dim = dim

    fake_tidb_sqlalchemy.VectorType = _VectorType

    monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector)
    monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy)
    monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock()))
    # Capture Column/Table arguments instead of building real SQLAlchemy objects.
    monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(
        tidb_module,
        "Table",
        lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns),
    )

    vector = tidb_module.TiDBVector("collection_1", _config(tidb_module))
    table = vector._table(3)

    assert table.name == "collection_1"
    # Each captured Column's first positional arg is its field name.
    column_names = [column.args[0] for column in table.columns]
    assert tidb_module.Field.PRIMARY_KEY in column_names
    assert tidb_module.Field.VECTOR in column_names
    assert tidb_module.Field.TEXT_KEY in column_names
|
||||
|
||||
def test_create_calls_collection_and_add_texts(tidb_module):
    """create() derives the dimension from embeddings and delegates to add_texts."""
    # Bypass __init__ so no engine/DB connection is needed.
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()

    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1, 0.2]])

    # Dimension 2 comes from the length of the first embedding.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
    assert vector._dimension == 2
|
||||
|
||||
|
||||
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
    """A redis cache hit means the collection exists, so no DB session is opened."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock))
    # get() returning a truthy value is the cache hit.
    monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock())

    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()

    # Direct assignment (not monkeypatch) is safe here because the tidb_module
    # fixture reloads the module for every test.
    tidb_module.Session = MagicMock()

    vector._create_collection(3)

    tidb_module.Session.assert_not_called()
    tidb_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch):
    """On a cache miss, _create_collection issues CREATE TABLE DDL and sets the cache flag."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock))
    # Cache miss: force the creation path.
    monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock())

    session = MagicMock()

    class _SessionCtx:
        # Minimal context-manager stand-in for sqlalchemy.orm.Session.
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())

    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector._distance_func = "l2"

    vector._create_collection(3)

    session.begin.assert_called_once()
    # The DDL must declare a 3-dim vector column and an L2-distance index.
    sql = str(session.execute.call_args.args[0])
    assert "VECTOR<FLOAT>(3)" in sql
    assert "VEC_L2_DISTANCE" in sql
    session.commit.assert_called_once()
    tidb_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
    """add_texts splits rows into 500-row INSERT batches and returns all doc ids."""

    class _InsertStmt:
        # Captures the rows passed to sqlalchemy insert(...).values(...).
        def __init__(self, table):
            self.table = table

        def values(self, rows):
            return {"table": self.table, "rows": rows}

    monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table))

    conn = MagicMock()
    transaction = MagicMock()
    transaction.__enter__.return_value = None
    transaction.__exit__.return_value = None
    conn.begin.return_value = transaction

    connection_ctx = MagicMock()
    connection_ctx.__enter__.return_value = conn
    connection_ctx.__exit__.return_value = None

    engine = MagicMock()
    engine.connect.return_value = connection_ctx

    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._engine = engine
    vector._table = MagicMock(return_value="table")

    # 501 documents -> one full batch of 500 plus one remainder batch.
    docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)]
    embeddings = [[float(i)] for i in range(501)]

    ids = vector.add_texts(docs, embeddings)

    assert ids[0] == "id-0"
    assert len(ids) == 501
    # Two execute calls prove the batching boundary at 500 rows.
    assert conn.execute.call_count == 2
|
||||
|
||||
|
||||
@pytest.fixture
def tidb_vector_with_session(tidb_module, monkeypatch):
    """Yield (vector, session, module) with Session patched to a canned MagicMock session."""
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    session = MagicMock()

    class _SessionCtx:
        # Minimal context-manager stand-in for sqlalchemy.orm.Session.
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    return vector, session, tidb_module
|
||||
|
||||
|
||||
# 1. search_by_full_text returns empty
def test_search_by_full_text_returns_empty(tidb_vector_with_session):
    """TiDB vector store does not support full-text search, so it must return []."""
    vector = tidb_vector_with_session[0]
    results = vector.search_by_full_text("query")
    assert results == []
|
||||
|
||||
|
||||
# 2. text_exists returns True when ids found
def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session):
    """text_exists is True when the metadata lookup yields at least one id."""
    vector, _session, _module = tidb_vector_with_session
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    result = vector.text_exists("doc-1")
    assert result is True
|
||||
|
||||
|
||||
# 3. text_exists returns False when no ids
def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session):
    """text_exists is False when the metadata lookup yields nothing."""
    vector, _session, _module = tidb_vector_with_session
    vector.get_ids_by_metadata_field = MagicMock(return_value=None)
    result = vector.text_exists("doc-1")
    assert result is False
|
||||
|
||||
|
||||
# 4. delete_by_ids delegates to _delete_by_ids when ids found
def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session):
    """delete_by_ids resolves document ids to row ids and forwards them to _delete_by_ids."""
    vector, session, tidb_module = tidb_vector_with_session
    session.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)]
    vector._delete_by_ids = MagicMock()
    # Use real get_ids_by_metadata_field
    vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__(
        vector, tidb_module.TiDBVector
    )
    vector.delete_by_ids(["doc-a", "doc-b"])
    vector._delete_by_ids.assert_called_once_with(["id-a", "id-b"])
|
||||
|
||||
|
||||
# 5. delete_by_ids skips when no ids found
def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session):
    """delete_by_ids must not call _delete_by_ids when the id lookup is empty."""
    vector, session, tidb_module = tidb_vector_with_session
    session.execute.return_value.fetchall.return_value = []
    vector._delete_by_ids = MagicMock()
    # Use real get_ids_by_metadata_field
    vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__(
        vector, tidb_module.TiDBVector
    )
    vector.delete_by_ids(["doc-c"])
    vector._delete_by_ids.assert_not_called()
|
||||
|
||||
|
||||
# 6. get_ids_by_metadata_field returns ids and returns None
def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session):
    """get_ids_by_metadata_field returns fetched ids, or None when nothing matches."""
    vector, session, tidb_module = tidb_vector_with_session
    # Returns ids
    session.execute.return_value.fetchall.return_value = [("id-1",)]
    assert vector.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"]
    # Returns None
    session.execute.return_value.fetchall.return_value = []
    assert vector.get_ids_by_metadata_field("doc_id", "doc-1") is None
|
||||
|
||||
|
||||
# 1. _delete_by_ids raises on None
def test__delete_by_ids_raises_on_none(tidb_module):
    """Passing None instead of an id list must raise a descriptive ValueError."""
    instance = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    with pytest.raises(ValueError, match="No ids provided"):
        instance._delete_by_ids(None)
|
||||
|
||||
|
||||
# 2. _delete_by_ids returns True and calls execute
def test__delete_by_ids_returns_true_and_calls_execute(tidb_module):
    """_delete_by_ids issues one DELETE through the engine connection and reports success."""

    class _IDColumn:
        # Mimics a sqlalchemy column supporting .in_(...).
        def in_(self, ids):
            return ids

    class _Delete:
        # Mimics table.delete().where(...).
        def where(self, condition):
            return condition

    table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = None
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn_ctx = MagicMock()
    conn_ctx.__enter__.return_value = conn
    conn_ctx.__exit__.return_value = None
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._dimension = 2
    vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx))
    vector._table = MagicMock(return_value=table)
    assert vector._delete_by_ids(["id-1"]) is True
    conn.execute.assert_called_once()
|
||||
|
||||
|
||||
# 3. _delete_by_ids returns False on RuntimeError
def test__delete_by_ids_returns_false_on_runtime_error(tidb_module):
    """_delete_by_ids swallows execution errors and reports failure with False."""

    class _IDColumn:
        # Mimics a sqlalchemy column supporting .in_(...).
        def in_(self, ids):
            return ids

    class _Delete:
        # Mimics table.delete().where(...).
        def where(self, condition):
            return condition

    table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = None
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn_ctx = MagicMock()
    conn_ctx.__enter__.return_value = conn
    conn_ctx.__exit__.return_value = None
    # Make the DELETE fail so the error-handling branch is exercised.
    conn.execute.side_effect = RuntimeError("delete failed")
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._dimension = 2
    vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx))
    vector._table = MagicMock(return_value=table)
    assert vector._delete_by_ids(["id-2"]) is False
|
||||
|
||||
|
||||
# 4. delete_by_metadata_field calls _delete_by_ids when ids found
def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module):
    """When the metadata lookup finds rows, their ids are forwarded to _delete_by_ids."""
    instance = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    instance.get_ids_by_metadata_field = MagicMock(return_value=["id-3"])
    instance._delete_by_ids = MagicMock()

    instance.delete_by_metadata_field("doc_id", "doc-3")

    instance._delete_by_ids.assert_called_once_with(["id-3"])
|
||||
|
||||
|
||||
# 5. delete_by_metadata_field does nothing when no ids
def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module):
    """An empty metadata lookup must short-circuit before any delete call."""
    instance = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    instance.get_ids_by_metadata_field = MagicMock(return_value=[])
    instance._delete_by_ids = MagicMock()

    instance.delete_by_metadata_field("doc_id", "doc-4")

    instance._delete_by_ids.assert_not_called()
|
||||
|
||||
|
||||
# Test search_by_vector filters and scores
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
    """search_by_vector converts cosine distance to score (1 - d) and applies filters."""
    session = MagicMock()
    # Rows are (metadata json, page content, distance).
    session.execute.return_value = [
        ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2),
        ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4),
    ]
    session.commit = MagicMock()

    class _SessionCtx:
        # Minimal context-manager stand-in for sqlalchemy.orm.Session.
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector._distance_func = "cosine"
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    assert len(docs) == 2
    # score = 1 - distance (0.2 -> 0.8, 0.4 -> 0.6).
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    assert docs[1].metadata["score"] == pytest.approx(0.6)
    sql = str(session.execute.call_args.args[0])
    params = session.execute.call_args.kwargs["params"]
    # Document-id filter must be pushed into the SQL WHERE clause.
    assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql
    assert params["distance"] == pytest.approx(0.5)
    assert params["top_k"] == 2
    session.commit.assert_not_called()
|
||||
|
||||
|
||||
# Test delete drops table
def test_delete_drops_table(tidb_module, monkeypatch):
    """delete() drops the backing table and commits the session."""
    session = MagicMock()
    session.execute.return_value = None
    session.commit = MagicMock()

    class _SessionCtx:
        # Minimal context-manager stand-in for sqlalchemy.orm.Session.
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector.delete()
    drop_sql = str(session.execute.call_args.args[0])
    assert "DROP TABLE IF EXISTS collection_1" in drop_sql
    session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
    """The factory reuses the index-struct collection name or generates one per dataset id."""
    factory = tidb_module.TiDBVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full TiDB connection settings the factory reads from dify_config.
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000)
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify")
    monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app")

    with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are lower-cased by the factory.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one written back.
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,186 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_upstash_module():
    """Construct a stand-in `upstash_vector` module so the wrapper imports offline."""
    upstash_module = types.ModuleType("upstash_vector")

    class Vector:
        # Mirrors the fields of upstash_vector.Vector used by the wrapper.
        def __init__(self, id, vector, metadata, data):
            self.id = id
            self.vector = vector
            self.metadata = metadata
            self.data = data

    class Index:
        # Mirrors upstash_vector.Index with MagicMock methods for call inspection.
        def __init__(self, url, token):
            self.url = url
            self.token = token
            self.info = MagicMock(return_value=SimpleNamespace(dimension=8))
            self.upsert = MagicMock()
            self.query = MagicMock(return_value=[])
            self.delete = MagicMock()
            self.reset = MagicMock()

    upstash_module.Vector = Vector
    upstash_module.Index = Index
    return upstash_module
|
||||
|
||||
|
||||
@pytest.fixture
def upstash_module(monkeypatch):
    """Import the Upstash vector module against the fake `upstash_vector` SDK."""
    # Remove patched modules if present
    for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]:
        if modname in sys.modules:
            monkeypatch.delitem(sys.modules, modname, raising=False)
    monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module())
    module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector")
    return module
|
||||
|
||||
|
||||
def _config(module):
    """Build a minimal valid UpstashVectorConfig for the reloaded module."""
    return module.UpstashVectorConfig(
        url="https://upstash.example",
        token="token-123",
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("url", "", "Upstash URL is required"),
        ("token", "", "Upstash Token is required"),
    ],
)
def test_upstash_config_validation(upstash_module, field, value, message):
    """An empty url/token must fail UpstashVectorConfig validation with a clear message."""
    # Start from a valid config dict, then blank out the field under test.
    values = _config(upstash_module).model_dump()
    values[field] = value

    with pytest.raises(ValidationError, match=message):
        upstash_module.UpstashVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_dimension(upstash_module, monkeypatch):
    """Covers type/dimension accessors and id generation during add_texts."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))

    assert vector.get_type() == upstash_module.VectorType.UPSTASH
    assert vector._table_name == "collection_1"
    assert vector._get_index_dimension() == 8

    # Falls back to the 1536 default when the index reports no dimension.
    vector.index.info.return_value = SimpleNamespace(dimension=None)
    assert vector._get_index_dimension() == 1536

    vector.index.info.return_value = None
    assert vector._get_index_dimension() == 1536

    # Pin uuid4 so the generated vector id is deterministic.
    monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid")
    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.add_texts(docs, [[0.1, 0.2]])

    vector.index.upsert.assert_called_once()
    upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"]
    assert upsert_vectors[0].id == "generated-uuid"
|
||||
|
||||
|
||||
def test_create_text_exists_and_delete_by_ids(upstash_module):
    """create delegates to add_texts; text_exists and delete_by_ids use metadata lookups."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    vector.add_texts = MagicMock()

    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1]])
    vector.add_texts.assert_called_once_with(docs, [[0.1]])

    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    assert vector.text_exists("doc-1") is True
    vector.get_ids_by_metadata_field.return_value = []
    assert vector.text_exists("doc-1") is False

    # delete_by_ids flattens the per-document lookups into a single delete call.
    vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]])
    vector._delete_by_ids = MagicMock()
    vector.delete_by_ids(["doc-1", "doc-2", "doc-3"])
    vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"])
|
||||
|
||||
|
||||
def test_delete_helpers_and_search(upstash_module):
    """Covers _delete_by_ids guards, metadata id lookup, and delete_by_metadata_field."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))

    # Empty input must not hit the index at all.
    vector._delete_by_ids([])
    vector.index.delete.assert_not_called()
    vector._delete_by_ids(["a", "b"])
    vector.index.delete.assert_called_once_with(ids=["a", "b"])

    vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")]
    ids = vector.get_ids_by_metadata_field("doc_id", "doc-1")
    assert ids == ["x-1", "x-2"]
    # Lookup queries with a large top_k and a metadata equality filter.
    query_kwargs = vector.index.query.call_args.kwargs
    assert query_kwargs["top_k"] == 1000
    assert query_kwargs["filter"] == "doc_id = 'doc-1'"

    vector._delete_by_ids = MagicMock()
    vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"])
    vector.delete_by_metadata_field("doc_id", "doc-1")
    vector._delete_by_ids.assert_called_once_with(["x-1"])

    # No matching ids: nothing is deleted.
    vector._delete_by_ids.reset_mock()
    vector.get_ids_by_metadata_field.return_value = []
    vector.delete_by_metadata_field("doc_id", "doc-2")
    vector._delete_by_ids.assert_not_called()
|
||||
|
||||
|
||||
def test_search_by_vector_filter_threshold_and_delete(upstash_module):
    """search_by_vector drops low scores and rows missing metadata/content; delete resets index."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    vector.index.query.return_value = [
        SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9),
        SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3),
        SimpleNamespace(metadata=None, data="text-3", score=0.99),
        SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99),
    ]

    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=3,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )

    # Only the first row survives: score >= 0.5 with both metadata and data present.
    assert len(docs) == 1
    assert docs[0].page_content == "text-1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)

    search_kwargs = vector.index.query.call_args.kwargs
    assert search_kwargs["top_k"] == 3
    assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')"

    # Full-text search is unsupported for Upstash.
    assert vector.search_by_full_text("query") == []

    vector.delete()
    vector.index.reset.assert_called_once()
|
||||
|
||||
|
||||
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch):
    """The factory reuses the index-struct collection name or generates one per dataset id."""
    factory = upstash_module.UpstashVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example")
    monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123")

    with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are lower-cased by the factory.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one written back.
    assert dataset_without_index.index_struct is not None
|
||||
@ -0,0 +1,310 @@
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from collections import UserDict
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def _build_fake_vikingdb_modules():
    """Build fake `volcengine` / `volcengine.viking_db` modules for offline import."""
    volcengine = types.ModuleType("volcengine")
    # Mark as a package so the `volcengine.viking_db` submodule import resolves.
    volcengine.__path__ = []
    viking_db = types.ModuleType("volcengine.viking_db")

    class Data(UserDict):
        # SDK Data type: dict-like with a `.fields` attribute aliasing the payload.
        def __init__(self, payload):
            super().__init__(payload)
            self.fields = payload

    class DistanceType:
        L2 = "L2"

    class IndexType:
        HNSW = "HNSW"

    class QuantType:
        Float = "Float"

    class FieldType:
        String = "string"
        Text = "text"
        Vector = "vector"

    class Field:
        # Captures constructor kwargs for later inspection.
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class VectorIndexParams:
        # Captures constructor kwargs for later inspection.
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _Collection:
        # Collection handle with MagicMock data operations for call inspection.
        def __init__(self):
            self.upsert_data = MagicMock()
            self.fetch_data = MagicMock(return_value=None)
            self.delete_data = MagicMock()

    class _Index:
        # Index handle with MagicMock search operations.
        def __init__(self):
            self.search = MagicMock(return_value=[])
            self.search_by_vector = MagicMock(return_value=[])

    class VikingDBService:
        # Service entry point; get_collection/get_index always return the same handles.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.create_collection = MagicMock()
            self.create_index = MagicMock()
            self.drop_index = MagicMock()
            self.drop_collection = MagicMock()
            self._collection = _Collection()
            self._index = _Index()
            self.get_collection = MagicMock(return_value=self._collection)
            self.get_index = MagicMock(return_value=self._index)

    viking_db.Data = Data
    viking_db.DistanceType = DistanceType
    viking_db.Field = Field
    viking_db.FieldType = FieldType
    viking_db.IndexType = IndexType
    viking_db.QuantType = QuantType
    viking_db.VectorIndexParams = VectorIndexParams
    viking_db.VikingDBService = VikingDBService

    return {"volcengine": volcengine, "volcengine.viking_db": viking_db}
|
||||
|
||||
|
||||
@pytest.fixture
def vikingdb_module(monkeypatch):
    """Reload the VikingDB vector module so it binds against the fake volcengine SDK."""
    fakes = _build_fake_vikingdb_modules()
    for mod_name, fake_mod in fakes.items():
        monkeypatch.setitem(sys.modules, mod_name, fake_mod)

    import core.rag.datasource.vdb.vikingdb.vikingdb_vector as vikingdb_vector

    # Reload in case the real module was already imported elsewhere.
    return importlib.reload(vikingdb_vector)
|
||||
|
||||
|
||||
def _config(module):
    """Build a fully-populated VikingDBConfig for the reloaded module."""
    settings = {
        "access_key": "ak",
        "secret_key": "sk",
        "host": "host",
        "region": "region",
        "scheme": "https",
        "connection_timeout": 10,
        "socket_timeout": 20,
    }
    return module.VikingDBConfig(**settings)
|
||||
|
||||
|
||||
def test_init_get_type_and_has_checks(vikingdb_module):
    """Covers get_type/index-name wiring plus _has_collection/_has_index error handling."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))

    assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB
    assert vector._index_name == "collection_1_idx"

    assert vector._has_collection() is True
    assert vector._has_index() is True

    # SDK errors are treated as "does not exist".
    vector._client.get_collection.side_effect = RuntimeError("missing")
    assert vector._has_collection() is False
    vector._client.get_collection.side_effect = None

    vector._client.get_index.side_effect = RuntimeError("missing")
    assert vector._has_index() is False
|
||||
|
||||
|
||||
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch):
    """_create_collection skips work on a redis cache hit and creates on a miss."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock())

    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))

    # Cache hit: nothing is created.
    monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    vector._client.create_collection.assert_not_called()
    vector._client.create_index.assert_not_called()

    # Cache miss with neither collection nor index present: both are created.
    monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None))
    vector._has_collection = MagicMock(return_value=False)
    vector._has_index = MagicMock(return_value=False)
    vector._create_collection(4)

    vector._client.create_collection.assert_called_once()
    vector._client.create_index.assert_called_once()
    vikingdb_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_add_texts(vikingdb_module):
    """create() sizes the collection from embeddings; add_texts upserts keyed rows."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()

    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1, 0.2]])

    # Dimension is taken from the length of the first embedding.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])

    # Fresh instance with a real add_texts to inspect the upsert payload.
    vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module))
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}),
        Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}),
    ]
    vector.add_texts(docs, [[0.1], [0.2]])

    vector._client.get_collection.assert_called()
    # Upserted rows carry the doc_id as primary key and the group as group key.
    upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0]
    assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a"
    assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2"
|
||||
|
||||
|
||||
def test_text_exists_and_delete_operations(vikingdb_module):
    """text_exists interprets fetch_data responses; delete paths forward the right ids."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))

    vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"})
    assert vector.text_exists("id-1") is True

    # An explicit "data does not exist" message means the id is absent.
    vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(
        fields={"message": "data does not exist"}
    )
    assert vector.text_exists("id-1") is False

    # A missing response is also treated as absent.
    vector._client.get_collection.return_value.fetch_data.return_value = None
    assert vector.text_exists("id-1") is False

    vector.delete_by_ids(["id-1"])
    vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"])

    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-2"])
|
||||
|
||||
|
||||
def test_get_ids_and_search_helpers(vikingdb_module):
    """Covers metadata id lookup, result scoring/sorting, and vector-search filtering."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))

    vector._client.get_index.return_value.search.return_value = []
    assert vector.get_ids_by_metadata_field("doc_id", "x") == []

    # Only hits whose metadata JSON matches the requested key/value are returned;
    # hits with no metadata field are skipped.
    vector._client.get_index.return_value.search.return_value = [
        SimpleNamespace(id="a", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}),
        SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}),
        SimpleNamespace(id="c", fields={}),
    ]
    assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"]

    empty_docs = vector._get_search_res([], score_threshold=0.1)
    assert empty_docs == []

    results = [
        SimpleNamespace(
            id="a",
            score=0.3,
            fields={
                vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a",
                vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}),
            },
        ),
        SimpleNamespace(
            id="b",
            score=0.9,
            fields={
                vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b",
                vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}),
            },
        ),
    ]

    # Results above the threshold come back sorted by descending score.
    docs = vector._get_search_res(results, score_threshold=0.2)
    assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"]

    vector._client.get_index.return_value.search_by_vector.return_value = results
    filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"])
    assert len(filtered_docs) == 1
    assert filtered_docs[0].page_content == "doc-b"
    # Full-text search is unsupported for VikingDB.
    assert vector.search_by_full_text("query") == []
|
||||
|
||||
|
||||
def test_delete_drops_index_and_collection_when_present(vikingdb_module):
    """delete() drops the index and collection only when they exist."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vector._has_index = MagicMock(return_value=True)
    vector._has_collection = MagicMock(return_value=True)

    vector.delete()

    vector._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx")
    vector._client.drop_collection.assert_called_once_with("collection_1")

    # Nothing to drop when neither exists.
    vector._client.drop_index.reset_mock()
    vector._client.drop_collection.reset_mock()
    vector._has_index.return_value = False
    vector._has_collection.return_value = False
    vector.delete()

    vector._client.drop_index.assert_not_called()
    vector._client.drop_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch):
    """The factory builds a vector from dify_config and reuses or generates collection names."""
    factory = vikingdb_module.VikingDBVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)

    monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")

    with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls:
        # Provide the full VikingDB settings the factory reads from dify_config.
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10)
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20)

        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())

    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are lower-cased by the factory.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one written back.
    assert dataset_without_index.index_struct is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field", "message"),
|
||||
[
|
||||
("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"),
|
||||
("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"),
|
||||
("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"),
|
||||
("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"),
|
||||
("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"),
|
||||
],
|
||||
)
|
||||
def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message):
|
||||
factory = vikingdb_module.VikingDBVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None
|
||||
)
|
||||
|
||||
monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak")
|
||||
monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk")
|
||||
monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host")
|
||||
monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region")
|
||||
monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https")
|
||||
monkeypatch.setattr(vikingdb_module.dify_config, field, None)
|
||||
|
||||
with pytest.raises(ValueError, match=message):
|
||||
factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
|
||||
@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in:
|
||||
- Full-text search result metadata (search_by_full_text)
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
from core.rag.models.document import Document
|
||||
@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
weaviate_vector_module._weaviate_client = None
|
||||
|
||||
def test_config_requires_endpoint(self):
|
||||
with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"):
|
||||
WeaviateConfig(endpoint="")
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def _create_weaviate_vector(self, mock_weaviate_module):
|
||||
"""Helper to create a WeaviateVector instance with mocked client."""
|
||||
@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
)
|
||||
return wv, mock_client
|
||||
|
||||
def test_shutdown_client_logs_debug_when_close_fails(self):
|
||||
mock_client = MagicMock()
|
||||
mock_client.close.side_effect = RuntimeError("close failed")
|
||||
weaviate_vector_module._weaviate_client = mock_client
|
||||
|
||||
with patch.object(weaviate_vector_module.logger, "debug") as mock_debug:
|
||||
weaviate_vector_module._shutdown_weaviate_client()
|
||||
|
||||
assert weaviate_vector_module._weaviate_client is None
|
||||
mock_client.close.assert_called_once()
|
||||
mock_debug.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
|
||||
def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect):
|
||||
cached_client = MagicMock()
|
||||
cached_client.is_ready.return_value = True
|
||||
weaviate_vector_module._weaviate_client = cached_client
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
|
||||
client = wv._init_client(self.config)
|
||||
|
||||
assert client is cached_client
|
||||
mock_connect.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
|
||||
def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect):
|
||||
cached_client = MagicMock()
|
||||
cached_client.is_ready.side_effect = [False, True]
|
||||
weaviate_vector_module._weaviate_client = cached_client
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
|
||||
client = wv._init_client(self.config)
|
||||
|
||||
assert client is cached_client
|
||||
mock_connect.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token")
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
|
||||
def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_connect.return_value = mock_client
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
config = WeaviateConfig(
|
||||
endpoint="https://weaviate.example.com",
|
||||
grpc_endpoint="grpc.example.com:6000",
|
||||
api_key="test-key",
|
||||
batch_size=50,
|
||||
)
|
||||
|
||||
client = wv._init_client(config)
|
||||
|
||||
assert client is mock_client
|
||||
assert mock_connect.call_args.kwargs == {
|
||||
"http_host": "weaviate.example.com",
|
||||
"http_port": 443,
|
||||
"http_secure": True,
|
||||
"grpc_host": "grpc.example.com",
|
||||
"grpc_port": 6000,
|
||||
"grpc_secure": False,
|
||||
"auth_credentials": "auth-token",
|
||||
"skip_init_checks": True,
|
||||
}
|
||||
mock_api_key.assert_called_once_with("test-key")
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
|
||||
def test_init_client_raises_when_database_not_ready(self, mock_connect):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = False
|
||||
mock_connect.return_value = mock_client
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
|
||||
with pytest.raises(ConnectionError, match="Vector database is not ready"):
|
||||
wv._init_client(self.config)
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_init(self, mock_weaviate_module):
|
||||
"""Test WeaviateVector initialization stores attributes including doc_type."""
|
||||
@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
assert wv._collection_name == self.collection_name
|
||||
assert "doc_type" in wv._attributes
|
||||
|
||||
def test_get_type_and_to_index_struct(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
|
||||
assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE
|
||||
assert wv.to_index_struct() == {
|
||||
"type": weaviate_vector_module.VectorType.WEAVIATE,
|
||||
"vector_store": {"class_prefix": self.collection_name},
|
||||
}
|
||||
|
||||
def test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self):
|
||||
dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1")
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
|
||||
assert wv.get_collection_name(dataset) == "ExistingCollection_Node"
|
||||
|
||||
def test_get_collection_name_generates_name_from_dataset_id(self):
|
||||
dataset = SimpleNamespace(index_struct_dict=None, id="ds-2")
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
|
||||
with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"):
|
||||
assert wv.get_collection_name(dataset) == "Generated_Node"
|
||||
|
||||
def test_create_calls_collection_setup_then_add_texts(self):
|
||||
doc = Document(page_content="hello", metadata={})
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._create_collection = MagicMock()
|
||||
wv.add_texts = MagicMock()
|
||||
|
||||
wv.create([doc], [[0.1, 0.2]])
|
||||
|
||||
wv._create_collection.assert_called_once()
|
||||
wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]])
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config")
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
f"doc_type should be in collection schema properties, got: {property_names}"
|
||||
)
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
|
||||
def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis):
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = 1
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._ensure_properties = MagicMock()
|
||||
|
||||
wv._create_collection()
|
||||
|
||||
wv._client.collections.exists.assert_not_called()
|
||||
wv._ensure_properties.assert_not_called()
|
||||
mock_redis.set.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
|
||||
def test_create_collection_logs_and_reraises_errors(self, mock_redis):
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock(return_value=False)
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.side_effect = RuntimeError("create failed")
|
||||
|
||||
with patch.object(weaviate_vector_module.logger, "exception") as mock_exception:
|
||||
with pytest.raises(RuntimeError, match="create failed"):
|
||||
wv._create_collection()
|
||||
|
||||
mock_exception.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module):
|
||||
"""Test that _ensure_properties adds doc_type when it's missing from existing schema."""
|
||||
@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
added_names = [call.args[0].name for call in add_calls]
|
||||
assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}"
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
mock_client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
mock_client.collections.use.return_value = mock_col
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.properties = [SimpleNamespace(name="text")]
|
||||
mock_col.config.get.return_value = mock_cfg
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
wv._ensure_properties()
|
||||
|
||||
add_calls = mock_col.config.add_property.call_args_list
|
||||
added_names = [call.args[0].name for call in add_calls]
|
||||
assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"]
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
|
||||
"""Test that _ensure_properties does not add doc_type when it already exists."""
|
||||
@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
# No properties should be added
|
||||
mock_col.config.add_property.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
mock_client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
mock_client.collections.use.return_value = mock_col
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.properties = []
|
||||
mock_col.config.get.return_value = mock_cfg
|
||||
mock_col.config.add_property.side_effect = RuntimeError("cannot add")
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
|
||||
with patch.object(weaviate_vector_module.logger, "warning") as mock_warning:
|
||||
wv._ensure_properties()
|
||||
|
||||
assert mock_warning.call_count == 4
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):
|
||||
"""Test that search_by_vector returns doc_type in document metadata.
|
||||
@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata.get("doc_type") == "image"
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
mock_client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
mock_client.collections.use.return_value = mock_col
|
||||
|
||||
mock_obj = MagicMock()
|
||||
mock_obj.properties = {
|
||||
"text": "fallback distance result",
|
||||
"document_id": "doc-1",
|
||||
"doc_id": "segment-1",
|
||||
}
|
||||
mock_obj.metadata = None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.objects = [mock_obj]
|
||||
mock_col.query.near_vector.return_value = mock_result
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
docs = wv.search_by_vector(
|
||||
query_vector=[0.2] * 3,
|
||||
document_ids_filter=["doc-1"],
|
||||
top_k=2,
|
||||
score_threshold=-1,
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["score"] == 0.0
|
||||
assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
mock_client.collections.exists.return_value = False
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
|
||||
assert wv.search_by_vector(query_vector=[0.1] * 3) == []
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
|
||||
"""Test that search_by_full_text also returns doc_type in document metadata."""
|
||||
@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata.get("doc_type") == "image"
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
mock_client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
mock_client.collections.use.return_value = mock_col
|
||||
|
||||
mock_obj = MagicMock()
|
||||
mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"}
|
||||
mock_obj.vector = [0.3, 0.4]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.objects = [mock_obj]
|
||||
mock_col.query.bm25.return_value = mock_result
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"])
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].vector == [0.3, 0.4]
|
||||
assert mock_col.query.bm25.call_args.kwargs["filters"] is not None
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
mock_client.collections.exists.return_value = False
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
|
||||
assert wv.search_by_full_text(query="missing") == []
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
|
||||
"""Test that add_texts includes doc_type from document metadata in stored properties."""
|
||||
@ -310,6 +611,135 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
stored_props = call_kwargs.kwargs.get("properties")
|
||||
assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}"
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module):
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
mock_col = MagicMock()
|
||||
mock_client.collections.use.return_value = mock_col
|
||||
|
||||
mock_batch = MagicMock()
|
||||
mock_batch.__enter__ = MagicMock(return_value=mock_batch)
|
||||
mock_batch.__exit__ = MagicMock(return_value=False)
|
||||
mock_col.batch.dynamic.return_value = mock_batch
|
||||
|
||||
created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC)
|
||||
doc = Document(page_content="text", metadata={"created_at": created_at})
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]),
|
||||
patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"),
|
||||
):
|
||||
ids = wv.add_texts(documents=[doc], embeddings=[[]])
|
||||
|
||||
assert ids == ["fallback-uuid"]
|
||||
call_kwargs = mock_batch.add_object.call_args
|
||||
assert call_kwargs.kwargs["uuid"] == "fallback-uuid"
|
||||
assert call_kwargs.kwargs["vector"] is None
|
||||
assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat()
|
||||
|
||||
def test_is_uuid_handles_invalid_values(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
|
||||
assert wv._is_uuid("123e4567-e89b-12d3-a456-426614174000") is True
|
||||
assert wv._is_uuid("not-a-uuid") is False
|
||||
|
||||
def test_delete_by_metadata_field_returns_when_collection_is_missing(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.return_value = False
|
||||
|
||||
wv.delete_by_metadata_field("doc_id", "segment-1")
|
||||
|
||||
wv._client.collections.use.assert_not_called()
|
||||
|
||||
def test_delete_by_metadata_field_deletes_matching_objects(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
wv._client.collections.use.return_value = mock_col
|
||||
|
||||
wv.delete_by_metadata_field("doc_id", "segment-1")
|
||||
|
||||
mock_col.data.delete_many.assert_called_once()
|
||||
|
||||
def test_delete_removes_collection_when_present(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.return_value = True
|
||||
|
||||
wv.delete()
|
||||
|
||||
wv._client.collections.delete.assert_called_once_with(self.collection_name)
|
||||
|
||||
def test_text_exists_handles_missing_and_present_documents(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.side_effect = [False, True]
|
||||
mock_col = MagicMock()
|
||||
wv._client.collections.use.return_value = mock_col
|
||||
mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()])
|
||||
|
||||
assert wv.text_exists("segment-1") is False
|
||||
assert wv.text_exists("segment-1") is True
|
||||
|
||||
def test_delete_by_ids_handles_missing_collections_and_404s(self):
|
||||
class FakeUnexpectedStatusCodeError(Exception):
|
||||
def __init__(self, status_code):
|
||||
super().__init__(f"status={status_code}")
|
||||
self.status_code = status_code
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.side_effect = [False, True]
|
||||
mock_col = MagicMock()
|
||||
wv._client.collections.use.return_value = mock_col
|
||||
mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None]
|
||||
|
||||
with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
|
||||
wv.delete_by_ids(["ignored"])
|
||||
wv.delete_by_ids(["missing-id", "ok-id"])
|
||||
|
||||
assert mock_col.data.delete_by_id.call_count == 2
|
||||
|
||||
def test_delete_by_ids_reraises_non_404_errors(self):
|
||||
class FakeUnexpectedStatusCodeError(Exception):
|
||||
def __init__(self, status_code):
|
||||
super().__init__(f"status={status_code}")
|
||||
self.status_code = status_code
|
||||
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
wv._collection_name = self.collection_name
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
wv._client.collections.use.return_value = mock_col
|
||||
mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500)
|
||||
|
||||
with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
|
||||
with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"):
|
||||
wv.delete_by_ids(["bad-id"])
|
||||
|
||||
def test_json_serializable_converts_datetime(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC)
|
||||
|
||||
assert wv._json_serializable(created_at) == created_at.isoformat()
|
||||
assert wv._json_serializable("plain") == "plain"
|
||||
|
||||
|
||||
class TestVectorDefaultAttributes(unittest.TestCase):
|
||||
"""Tests for Vector class default attributes list."""
|
||||
@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase):
|
||||
assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}"
|
||||
|
||||
|
||||
class TestWeaviateVectorFactory(unittest.TestCase):
|
||||
def test_init_vector_uses_existing_dataset_index_struct(self):
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}},
|
||||
index_struct=None,
|
||||
)
|
||||
attributes = ["doc_id"]
|
||||
|
||||
with (
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"),
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"),
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"),
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88),
|
||||
patch(
|
||||
"core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector"
|
||||
) as mock_vector,
|
||||
):
|
||||
factory = weaviate_vector_module.WeaviateVectorFactory()
|
||||
result = factory.init_vector(dataset, attributes, MagicMock())
|
||||
|
||||
assert result == "vector"
|
||||
config = mock_vector.call_args.kwargs["config"]
|
||||
assert mock_vector.call_args.kwargs["collection_name"] == "ExistingCollection_Node"
|
||||
assert mock_vector.call_args.kwargs["attributes"] == attributes
|
||||
assert config.endpoint == "http://localhost:8080"
|
||||
assert config.grpc_endpoint == "localhost:50051"
|
||||
assert config.api_key == "api-key"
|
||||
assert config.batch_size == 88
|
||||
assert dataset.index_struct is None
|
||||
|
||||
def test_init_vector_generates_collection_and_updates_index_struct(self):
|
||||
dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
|
||||
attributes = ["doc_id", "doc_type"]
|
||||
|
||||
with (
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"),
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""),
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None),
|
||||
patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100),
|
||||
patch.object(
|
||||
weaviate_vector_module.Dataset,
|
||||
"gen_collection_name_by_id",
|
||||
return_value="GeneratedCollection_Node",
|
||||
),
|
||||
patch(
|
||||
"core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector"
|
||||
) as mock_vector,
|
||||
):
|
||||
factory = weaviate_vector_module.WeaviateVectorFactory()
|
||||
result = factory.init_vector(dataset, attributes, MagicMock())
|
||||
|
||||
assert result == "vector"
|
||||
assert mock_vector.call_args.kwargs["collection_name"] == "GeneratedCollection_Node"
|
||||
assert json.loads(dataset.index_struct) == {
|
||||
"type": weaviate_vector_module.VectorType.WEAVIATE,
|
||||
"vector_store": {"class_prefix": "GeneratedCollection_Node"},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -1164,7 +1164,7 @@ class TestConversationStatusCount:
|
||||
conversation.id = str(uuid4())
|
||||
|
||||
# Mock the database query to return no messages
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
@ -1189,7 +1189,7 @@ class TestConversationStatusCount:
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock the database query to return no messages with workflow_run_id
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
@ -1274,7 +1274,7 @@ class TestConversationStatusCount:
|
||||
return mock_result
|
||||
|
||||
# Act & Assert
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
result = conversation.status_count
|
||||
|
||||
# Verify only 2 database queries were made (not N+1)
|
||||
@ -1337,7 +1337,7 @@ class TestConversationStatusCount:
|
||||
return mock_result
|
||||
|
||||
# Act
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert - query should include app_id filter
|
||||
@ -1382,7 +1382,7 @@ class TestConversationStatusCount:
|
||||
),
|
||||
]
|
||||
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
# Mock the messages query
|
||||
def mock_scalars_side_effect(query):
|
||||
mock_result = MagicMock()
|
||||
@ -1438,7 +1438,7 @@ class TestConversationStatusCount:
|
||||
),
|
||||
]
|
||||
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
|
||||
def mock_scalars_side_effect(query):
|
||||
mock_result = MagicMock()
|
||||
|
||||
@ -13,6 +13,10 @@ import pytest
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
def _oauth_proxy_setex_calls(redis_client) -> list:
|
||||
return [call for call in redis_client.setex.call_args_list if call.args[0].startswith("oauth_proxy_context:")]
|
||||
|
||||
|
||||
class TestCreateProxyContext:
|
||||
def test_stores_context_in_redis_with_ttl(self):
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
@ -22,8 +26,9 @@ class TestCreateProxyContext:
|
||||
assert context_id # non-empty UUID string
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
redis_client.setex.assert_called_once()
|
||||
call_args = redis_client.setex.call_args
|
||||
oauth_calls = _oauth_proxy_setex_calls(redis_client)
|
||||
assert len(oauth_calls) == 1
|
||||
call_args = oauth_calls[0]
|
||||
key = call_args[0][0]
|
||||
ttl = call_args[0][1]
|
||||
stored_data = json.loads(call_args[0][2])
|
||||
|
||||
@ -211,6 +211,7 @@ def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch
|
||||
|
||||
def test_import_app_pending_stores_import_info_in_redis():
|
||||
service = AppDslService(MagicMock())
|
||||
app_dsl_service.redis_client.setex.reset_mock()
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
@ -375,10 +376,13 @@ def test_confirm_import_success_deletes_redis_key(monkeypatch):
|
||||
created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1")
|
||||
monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app)
|
||||
|
||||
app_dsl_service.redis_client.delete.reset_mock()
|
||||
result = service.confirm_import(import_id="import-1", account=_account_mock())
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
assert result.app_id == "confirmed-app"
|
||||
app_dsl_service.redis_client.delete.assert_called_once()
|
||||
app_dsl_service.redis_client.delete.assert_called_once_with(
|
||||
f"{app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX}import-1"
|
||||
)
|
||||
|
||||
|
||||
def test_confirm_import_exception_returns_failed(monkeypatch):
|
||||
|
||||
@ -405,7 +405,7 @@ class TestAudioServiceTTS:
|
||||
voice="en-US-Neural",
|
||||
)
|
||||
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
|
||||
"""Test successful TTS with message ID."""
|
||||
@ -549,7 +549,7 @@ class TestAudioServiceTTS:
|
||||
with pytest.raises(ValueError, match="Text is required"):
|
||||
AudioService.transcript_tts(app_model=app, text=None)
|
||||
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
@patch("services.audio_service.db.session")
|
||||
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None for invalid message ID format."""
|
||||
# Arrange
|
||||
@ -564,7 +564,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
@patch("services.audio_service.db.session")
|
||||
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message doesn't exist."""
|
||||
# Arrange
|
||||
@ -585,7 +585,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
@patch("services.audio_service.db.session")
|
||||
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message answer is empty."""
|
||||
# Arrange
|
||||
|
||||
@ -313,7 +313,8 @@ class TestEmailDeliveryTestHandler:
|
||||
recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")],
|
||||
)
|
||||
|
||||
subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com")
|
||||
with patch.object(dify_config, "APP_WEB_URL", "http://example.com"):
|
||||
subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com")
|
||||
|
||||
assert subs["node_title"] == "title"
|
||||
assert subs["form_content"] == "content"
|
||||
|
||||
@ -316,7 +316,7 @@ class TestTagServiceRetrieval:
|
||||
- get_tags_by_target_id: Get all tags bound to a specific target
|
||||
"""
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_tags_with_binding_counts(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags with their binding counts.
|
||||
@ -373,7 +373,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify database query was called
|
||||
mock_db_session.query.assert_called_once()
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_tags_with_keyword_filter(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags filtered by keyword (case-insensitive).
|
||||
@ -427,7 +427,7 @@ class TestTagServiceRetrieval:
|
||||
# 2. Additional WHERE clause for keyword filtering
|
||||
assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause"
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_target_ids_by_tag_ids(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving target IDs by tag IDs.
|
||||
@ -483,7 +483,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify both queries were executed
|
||||
assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query"
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that empty tag_ids returns empty list.
|
||||
@ -511,7 +511,7 @@ class TestTagServiceRetrieval:
|
||||
assert results == [], "Should return empty list for empty input"
|
||||
mock_db_session.scalars.assert_not_called(), "Should not query database for empty input"
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_tag_by_tag_name(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags by name.
|
||||
@ -553,7 +553,7 @@ class TestTagServiceRetrieval:
|
||||
assert len(results) == 1, "Should find exactly one tag"
|
||||
assert results[0].name == tag_name, "Tag name should match"
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that missing tag_type or tag_name returns empty list.
|
||||
@ -581,7 +581,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify no database queries were executed
|
||||
mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input"
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_tags_by_target_id(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags associated with a specific target.
|
||||
@ -654,7 +654,7 @@ class TestTagServiceCRUD:
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.uuid.uuid4", autospec=True)
|
||||
def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
@ -743,7 +743,7 @@ class TestTagServiceCRUD:
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
Test updating a tag name.
|
||||
@ -795,7 +795,7 @@ class TestTagServiceCRUD:
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_update_tags_raises_error_for_duplicate_name(
|
||||
self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory
|
||||
):
|
||||
@ -827,7 +827,7 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(ValueError, match="Tag name already exists"):
|
||||
TagService.update_tags(args, tag_id="tag-123")
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that updating a non-existent tag raises NotFound.
|
||||
@ -859,7 +859,7 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(NotFound, match="Tag not found"):
|
||||
TagService.update_tags(args, tag_id="nonexistent")
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_get_tag_binding_count(self, mock_db_session, factory):
|
||||
"""
|
||||
Test getting the count of bindings for a tag.
|
||||
@ -895,7 +895,7 @@ class TestTagServiceCRUD:
|
||||
# Verify count matches expectation
|
||||
assert result == expected_count, "Binding count should match"
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_delete_tag(self, mock_db_session, factory):
|
||||
"""
|
||||
Test deleting a tag and its bindings.
|
||||
@ -951,7 +951,7 @@ class TestTagServiceCRUD:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_delete_tag_raises_not_found(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that deleting a non-existent tag raises NotFound.
|
||||
@ -999,7 +999,7 @@ class TestTagServiceBindings:
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory):
|
||||
"""
|
||||
Test creating tag bindings.
|
||||
@ -1050,7 +1050,7 @@ class TestTagServiceBindings:
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory):
|
||||
"""
|
||||
Test that saving duplicate bindings is idempotent.
|
||||
@ -1090,7 +1090,7 @@ class TestTagServiceBindings:
|
||||
mock_db_session.add.assert_not_called(), "Should not create duplicate binding"
|
||||
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory):
|
||||
"""
|
||||
Test deleting a tag binding.
|
||||
@ -1138,7 +1138,7 @@ class TestTagServiceBindings:
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory):
|
||||
"""
|
||||
Test that deleting a non-existent binding is a no-op.
|
||||
@ -1175,7 +1175,7 @@ class TestTagServiceBindings:
|
||||
mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete"
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test validating that a dataset target exists.
|
||||
@ -1216,7 +1216,7 @@ class TestTagServiceBindings:
|
||||
mock_db_session.query.assert_called_once(), "Should query database for dataset"
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test validating that an app target exists.
|
||||
@ -1257,7 +1257,7 @@ class TestTagServiceBindings:
|
||||
mock_db_session.query.assert_called_once(), "Should query database for app"
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_check_target_exists_raises_not_found_for_missing_dataset(
|
||||
self, mock_db_session, mock_current_user, factory
|
||||
):
|
||||
@ -1289,7 +1289,7 @@ class TestTagServiceBindings:
|
||||
TagService.check_target_exists("knowledge", "nonexistent")
|
||||
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.db.session")
|
||||
def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test that missing app raises NotFound.
|
||||
|
||||
@ -59,6 +59,11 @@ def mock_redis():
|
||||
# Redis is already mocked globally in conftest.py
|
||||
# Reset it for each test
|
||||
redis_client.reset_mock()
|
||||
redis_client.get.reset_mock()
|
||||
redis_client.setex.reset_mock()
|
||||
redis_client.delete.reset_mock()
|
||||
redis_client.lpush.reset_mock()
|
||||
redis_client.rpop.reset_mock()
|
||||
redis_client.get.return_value = None
|
||||
redis_client.setex.return_value = True
|
||||
redis_client.delete.return_value = True
|
||||
|
||||
Reference in New Issue
Block a user