Merging main into fix/chore-fix

Yeuoly
2024-10-14 16:22:12 +08:00
433 changed files with 11823 additions and 2782 deletions

@ -0,0 +1,49 @@
from typing import Any
import toml
def load_api_poetry_configs() -> dict[str, Any]:
pyproject_toml = toml.load("api/pyproject.toml")
return pyproject_toml["tool"]["poetry"]
def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
configs = load_api_poetry_configs()
configs_by_group = {"main": configs}
for group_name in configs["group"]:
configs_by_group[group_name] = configs["group"][group_name]
dependencies_by_group = {group_name: base["dependencies"] for group_name, base in configs_by_group.items()}
return dependencies_by_group
def test_group_dependencies_sorted():
for group_name, dependencies in load_all_dependency_groups().items():
dependency_names = list(dependencies.keys())
expected_dependency_names = sorted(set(dependency_names))
section = f"tool.poetry.group.{group_name}.dependencies" if group_name else "tool.poetry.dependencies"
assert expected_dependency_names == dependency_names, (
f"Dependencies in group {group_name} are not sorted. "
f"Check and fix [{section}] section in pyproject.toml file"
)
def test_group_dependencies_version_operator():
for group_name, dependencies in load_all_dependency_groups().items():
for dependency_name, specification in dependencies.items():
version_spec = specification if isinstance(specification, str) else specification["version"]
assert not version_spec.startswith("^"), (
f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' "
f"'^' operator is too wide and not allowed in the version specification."
)
def test_duplicated_dependency_crossing_groups():
all_dependency_names: list[str] = []
for dependencies in load_all_dependency_groups().values():
dependency_names = list(dependencies.keys())
all_dependency_names.extend(dependency_names)
expected_all_dependency_names = set(all_dependency_names)
assert sorted(expected_all_dependency_names) == sorted(
all_dependency_names
), "Duplicated dependencies crossing groups are found"

@ -1,46 +0,0 @@
import os
from collections.abc import Callable
from typing import Literal
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from openai.resources.moderations import Moderations
from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
def mock_openai(
monkeypatch: MonkeyPatch,
methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
"""
mock openai module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
if "moderation" in methods:
monkeypatch.setattr(Moderations, "create", MockModerationClass.moderation_create)
return unpatch
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_openai_mock(request, monkeypatch):
methods = request.param if hasattr(request, "param") else []
if MOCK:
unpatch = mock_openai(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()

@ -1,101 +0,0 @@
import re
from typing import Any, Literal, Union
from openai._types import NOT_GIVEN, NotGiven
from openai.resources.moderations import Moderations
from openai.types import ModerationCreateResponse
from openai.types.moderation import Categories, CategoryScores, Moderation
from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockModerationClass:
def moderation_create(
self: Moderations,
*,
input: Union[str, list[str]],
model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ModerationCreateResponse:
if isinstance(input, str):
input = [input]
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
raise InvokeAuthorizationError("Invalid base url")
if len(self._client.api_key) < 18:
raise InvokeAuthorizationError("Invalid API key")
for text in input:
result = []
if "kill" in text:
moderation_categories = {
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
"harassment": 1.0,
"harassment/threatening": 1.0,
"hate": 1.0,
"hate/threatening": 1.0,
"self-harm": 1.0,
"self-harm/instructions": 1.0,
"self-harm/intent": 1.0,
"sexual": 1.0,
"sexual/minors": 1.0,
"violence": 1.0,
"violence/graphic": 1.0,
}
result.append(
Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
else:
moderation_categories = {
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
"harassment": 0.0,
"harassment/threatening": 0.0,
"hate": 0.0,
"hate/threatening": 0.0,
"self-harm": 0.0,
"self-harm/instructions": 0.0,
"self-harm/intent": 0.0,
"sexual": 0.0,
"sexual/minors": 0.0,
"violence": 0.0,
"violence/graphic": 0.0,
}
result.append(
Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)

@ -1,44 +0,0 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = OpenAIModerationModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"})
model.validate_credentials(
model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
)
@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = OpenAIModerationModel()
result = model.invoke(
model="text-moderation-stable",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
text="hello",
user="abc-123",
)
assert isinstance(result, bool)
assert result is False
result = model.invoke(
model="text-moderation-stable",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
text="i will kill you",
user="abc-123",
)
assert isinstance(result, bool)
assert result is True

@ -1,57 +0,0 @@
import logging
import os
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderExtension, ModelProviderFactory
logger = logging.getLogger(__name__)
def test_get_providers():
factory = ModelProviderFactory("test_tenant")
providers = factory.get_providers()
for provider in providers:
logger.debug(provider)
assert len(providers) >= 1
assert isinstance(providers[0], ProviderEntity)
def test_get_models():
factory = ModelProviderFactory("test_tenant")
providers = factory.get_models(
model_type=ModelType.LLM,
provider_configs=[
ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
],
)
logger.debug(providers)
assert len(providers) >= 1
assert isinstance(providers[0], SimpleProviderEntity)
# all provider models type equals to ModelType.LLM
for provider in providers:
for provider_model in provider.models:
assert provider_model.model_type == ModelType.LLM
providers = factory.get_models(
provider="openai",
provider_configs=[
ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
],
)
assert len(providers) == 1
assert isinstance(providers[0], SimpleProviderEntity)
assert providers[0].provider == "openai"
def test_provider_credentials_validate():
factory = ModelProviderFactory("test_tenant")
factory.provider_credentials_validate(
provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
)

@ -1,11 +0,0 @@
import os
import tiktoken
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
def test_tiktoken():
os.environ["TIKTOKEN_CACHE_DIR"] = "/tmp/.tiktoken_cache"
GPT2Tokenizer.get_num_tokens("Hello, world!")
assert tiktoken.registry.ENCODING_CONSTRUCTORS is not None

@ -1,25 +0,0 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.voyage.voyage import VoyageProvider
def test_validate_provider_credentials():
provider = VoyageProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
"model": "voyage-3",
"usage": {"total_tokens": 1},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})

@ -1,92 +0,0 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
def test_validate_credentials():
model = VoyageRerankModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="rerank-lite-1",
credentials={"api_key": "invalid_key"},
)
with patch("httpx.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [
{
"relevance_score": 0.546875,
"index": 0,
"document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
"States Census, Carson City had a population of 55,274.",
},
{
"relevance_score": 0.4765625,
"index": 1,
"document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
"Pacific Ocean that are a political division controlled by the United States. Its "
"capital is Saipan.",
},
],
"model": "rerank-lite-1",
"usage": {"total_tokens": 96},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
model.validate_credentials(
model="rerank-lite-1",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
)
def test_invoke_model():
model = VoyageRerankModel()
with patch("httpx.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [
{
"relevance_score": 0.84375,
"index": 0,
"document": "Kasumi is a girl name of Japanese origin meaning mist.",
},
{
"relevance_score": 0.4765625,
"index": 1,
"document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
"leads a team named PopiParty.",
},
],
"model": "rerank-lite-1",
"usage": {"total_tokens": 59},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
result = model.invoke(
model="rerank-lite-1",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl name of Japanese origin meaning mist.",
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
"PopiParty.",
],
score_threshold=0.5,
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 1
assert result.docs[0].index == 0
assert result.docs[0].score >= 0.5

@ -1,70 +0,0 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel
def test_validate_credentials():
model = VoyageTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"})
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
"model": "voyage-3",
"usage": {"total_tokens": 1},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
def test_invoke_model():
model = VoyageTextEmbeddingModel()
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0},
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1},
],
"model": "voyage-3",
"usage": {"total_tokens": 2},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
result = model.invoke(
model="voyage-3",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 2
def test_get_num_tokens():
model = VoyageTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model="voyage-3",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
texts=["ping"],
)
assert num_tokens == 1

@ -0,0 +1,154 @@
import os
import pytest
from _pytest.monkeypatch import MonkeyPatch
from pymochow import MochowClient
from pymochow.model.database import Database
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
from pymochow.model.schema import HNSWParams, VectorIndex
from pymochow.model.table import Table
from requests.adapters import HTTPAdapter
class MockBaiduVectorDBClass:
def mock_vector_db_client(
self,
config=None,
adapter: HTTPAdapter = None,
):
self._conn = None
self._config = None
def list_databases(self, config=None) -> list[Database]:
return [
Database(
conn=self._conn,
database_name="dify",
config=self._config,
)
]
def create_database(self, database_name: str, config=None) -> Database:
return Database(conn=self._conn, database_name=database_name, config=config)
def list_table(self, config=None) -> list[Table]:
return []
def drop_table(self, table_name: str, config=None):
return {"code": 0, "msg": "Success"}
def create_table(
self,
table_name: str,
replication: int,
partition: int,
schema,
enable_dynamic_field=False,
description: str = "",
config=None,
) -> Table:
return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
def describe_table(self, table_name: str, config=None) -> Table:
return Table(
self,
table_name,
3,
1,
None,
enable_dynamic_field=False,
description="table for dify",
config=config,
state=TableState.NORMAL,
)
def upsert(self, rows, config=None):
return {"code": 0, "msg": "operation success", "affectedCount": 1}
def rebuild_index(self, index_name: str, config=None):
return {"code": 0, "msg": "Success"}
def describe_index(self, index_name: str, config=None):
return VectorIndex(
index_name=index_name,
index_type=IndexType.HNSW,
field="vector",
metric_type=MetricType.L2,
params=HNSWParams(m=16, efconstruction=200),
auto_build=False,
state=IndexState.NORMAL,
)
def query(
self,
primary_key,
partition_key=None,
projections=None,
retrieve_vector=False,
read_consistency=ReadConsistency.EVENTUAL,
config=None,
):
return {
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"code": 0,
"msg": "Success",
}
def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
return {"code": 0, "msg": "Success"}
def search(
self,
anns,
partition_key=None,
projections=None,
retrieve_vector=False,
read_consistency=ReadConsistency.EVENTUAL,
config=None,
):
return {
"rows": [
{
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"distance": 0.1,
"score": 0.5,
}
],
"code": 0,
"msg": "Success",
}
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
yield
if MOCK:
monkeypatch.undo()

@ -48,7 +48,7 @@ class MockTcvectordbClass:
description: str,
index: Index,
embedding: Embedding = None,
timeout: float = None,
timeout: Optional[float] = None,
) -> Collection:
return Collection(
self,
@ -97,9 +97,9 @@ class MockTcvectordbClass:
def collection_delete(
self,
document_ids: list[str] = None,
document_ids: Optional[list[str]] = None,
filter: Filter = None,
timeout: float = None,
timeout: Optional[float] = None,
):
return {"code": 0, "msg": "operation success"}

@ -0,0 +1,215 @@
import os
from typing import Union
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from volcengine.viking_db import (
Collection,
Data,
DistanceType,
Field,
FieldType,
Index,
IndexType,
QuantType,
VectorIndexParams,
VikingDBService,
)
from core.rag.datasource.vdb.field import Field as vdb_Field
class MockVikingDBClass:
def __init__(
self,
host="api-vikingdb.volces.com",
region="cn-north-1",
ak="",
sk="",
scheme="http",
connection_timeout=30,
socket_timeout=30,
proxy=None,
):
self._viking_db_service = MagicMock()
self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
def get_collection(self, collection_name) -> Collection:
return Collection(
collection_name=collection_name,
description="Collection For Dify",
viking_db_service=self._viking_db_service,
primary_key=vdb_Field.PRIMARY_KEY.value,
fields=[
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
],
indexes=[
Index(
collection_name=collection_name,
index_name=f"{collection_name}_idx",
vector_index=VectorIndexParams(
distance=DistanceType.L2,
index_type=IndexType.HNSW,
quant=QuantType.Float,
),
scalar_index=None,
stat=None,
viking_db_service=self._viking_db_service,
)
],
)
def drop_collection(self, collection_name):
assert collection_name != ""
def create_collection(self, collection_name, fields, description="") -> Collection:
return Collection(
collection_name=collection_name,
description=description,
primary_key=vdb_Field.PRIMARY_KEY.value,
viking_db_service=self._viking_db_service,
fields=fields,
)
def get_index(self, collection_name, index_name) -> Index:
return Index(
collection_name=collection_name,
index_name=index_name,
viking_db_service=self._viking_db_service,
stat=None,
scalar_index=None,
vector_index=VectorIndexParams(
distance=DistanceType.L2,
index_type=IndexType.HNSW,
quant=QuantType.Float,
),
)
def create_index(
self,
collection_name,
index_name,
vector_index=None,
cpu_quota=2,
description="",
partition_by="",
scalar_index=None,
shard_count=None,
shard_policy=None,
):
return Index(
collection_name=collection_name,
index_name=index_name,
vector_index=vector_index,
cpu_quota=cpu_quota,
description=description,
partition_by=partition_by,
scalar_index=scalar_index,
shard_count=shard_count,
shard_policy=shard_policy,
viking_db_service=self._viking_db_service,
stat=None,
)
def drop_index(self, collection_name, index_name):
assert collection_name != ""
assert index_name != ""
def upsert_data(self, data: Union[Data, list[Data]]):
assert data is not None
def fetch_data(self, id: Union[str, list[str], int, list[int]]):
return Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: "{}",
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: id,
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
},
id=id,
)
def delete_data(self, id: Union[str, list[str], int, list[int]]):
assert id is not None
def search_by_vector(
self,
vector,
sparse_vectors=None,
filter=None,
limit=10,
output_fields=None,
partition="default",
dense_weight=None,
) -> list[Data]:
return [
Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: '\
{"source": "/var/folders/ml/xxx/xxx.txt", \
"document_id": "test_document_id", \
"dataset_id": "test_dataset_id", \
"doc_id": "test_id", \
"doc_hash": "test_hash"}',
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: "test_id",
vdb_Field.VECTOR.value: vector,
},
id="test_id",
score=0.10,
)
]
def search(
self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
) -> list[Data]:
return [
Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: '\
{"source": "/var/folders/ml/xxx/xxx.txt", \
"document_id": "test_document_id", \
"dataset_id": "test_dataset_id", \
"doc_id": "test_id", \
"doc_hash": "test_hash"}',
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: "test_id",
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
},
id="test_id",
score=0.10,
)
]
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__)
monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection)
monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection)
monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection)
monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index)
monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index)
monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index)
monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data)
monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data)
monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data)
monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector)
monkeypatch.setattr(Index, "search", MockVikingDBClass.search)
yield
if MOCK:
monkeypatch.undo()

View File

@ -0,0 +1,36 @@
from unittest.mock import MagicMock
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]
class BaiduVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = BaiduVector(
"dify",
BaiduConfig(
endpoint="http://127.0.0.1:5287",
account="root",
api_key="dify",
database="dify",
shard=1,
replicas=3,
),
)
def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1
def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
BaiduVectorTest().run_all_tests()

@ -1,5 +1,4 @@
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,

@ -1,4 +1,3 @@
import random
import uuid
from unittest.mock import MagicMock

@ -0,0 +1,37 @@
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector
from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class VikingDBVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = VikingDBVector(
"test_collection",
"test_group",
config=VikingDBConfig(
access_key="test_access_key",
host="test_host",
region="test_region",
scheme="test_scheme",
secret_key="test_secret_key",
connection_timeout=30,
socket_timeout=30,
),
)
def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1
def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id")
assert len(ids) > 0
def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock):
VikingDBVectorTest().run_all_tests()

@ -1,5 +1,5 @@
import os
from typing import Literal, Optional
from typing import Literal
import pytest
from _pytest.monkeypatch import MonkeyPatch

@ -1,4 +1,3 @@
import json
import os
import time
import uuid

@ -0,0 +1,38 @@
import pytest
from controllers.console.version import _has_new_version
@pytest.mark.parametrize(
("latest_version", "current_version", "expected"),
[
("1.0.1", "1.0.0", True),
("1.1.0", "1.0.0", True),
("2.0.0", "1.9.9", True),
("1.0.0", "1.0.0", False),
("1.0.0", "1.0.1", False),
("1.0.0", "2.0.0", False),
("1.0.1", "1.0.0-beta", True),
("1.0.0", "1.0.0-alpha", True),
("1.0.0-beta", "1.0.0-alpha", True),
("1.0.0", "1.0.0-rc1", True),
("1.0.0", "0.9.9", True),
("1.0.0", "1.0.0-dev", True),
],
)
def test_has_new_version(latest_version, current_version, expected):
assert _has_new_version(latest_version=latest_version, current_version=current_version) == expected
def test_has_new_version_invalid_input():
with pytest.raises(ValueError):
_has_new_version(latest_version="1.0", current_version="1.0.0")
with pytest.raises(ValueError):
_has_new_version(latest_version="1.0.0", current_version="1.0")
with pytest.raises(ValueError):
_has_new_version(latest_version="invalid", current_version="1.0.0")
with pytest.raises(ValueError):
_has_new_version(latest_version="1.0.0", current_version="invalid")

@ -2,7 +2,6 @@ import pytest
from pydantic import ValidationError
from core.app.segments import (
ArrayAnyVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,

@ -1,9 +1,6 @@
import os
from unittest import mock
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.models.document import Document
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response