mirror of
https://github.com/langgenius/dify.git
synced 2026-03-23 15:27:53 +08:00
test: add unit tests for services-part-1 (#33050)
This commit is contained in:
@ -704,7 +704,7 @@ class MCPToolManageService:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
raise error
|
||||
|
||||
def _is_valid_url(self, url: str) -> bool:
|
||||
"""Validate URL format."""
|
||||
|
||||
558
api/tests/unit_tests/services/test_metadata_service.py
Normal file
558
api/tests/unit_tests/services/test_metadata_service.py
Normal file
@ -0,0 +1,558 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DocumentStub:
|
||||
id: str
|
||||
name: str
|
||||
uploader: str
|
||||
upload_date: datetime
|
||||
last_update_date: datetime
|
||||
data_source_type: str
|
||||
doc_metadata: dict[str, object] | None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
mocked_db = mocker.patch("services.metadata_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.metadata_service.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_account(mocker: MockerFixture) -> MagicMock:
|
||||
mock_user = SimpleNamespace(id="user-1")
|
||||
return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1"))
|
||||
|
||||
|
||||
def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
|
||||
now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
|
||||
return _DocumentStub(
|
||||
id=document_id,
|
||||
name=f"doc-{document_id}",
|
||||
uploader="qa@example.com",
|
||||
upload_date=now,
|
||||
last_update_date=now,
|
||||
data_source_type="upload_file",
|
||||
doc_metadata=doc_metadata,
|
||||
)
|
||||
|
||||
|
||||
def _dataset(**kwargs: Any) -> Dataset:
|
||||
return cast(Dataset, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="x" * 256)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="priority")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_persist_metadata_when_input_is_valid(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="number", name="score")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
assert result.tenant_id == "tenant-1"
|
||||
assert result.dataset_id == "dataset-1"
|
||||
assert result.type == "number"
|
||||
assert result.name == "score"
|
||||
assert result.created_by == "user-1"
|
||||
mock_db.session.add.assert_called_once_with(result)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
too_long_name = "x" * 256
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name)
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate")
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_update_bound_documents_and_return_metadata(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC)
|
||||
mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now)
|
||||
|
||||
metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None)
|
||||
bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings]
|
||||
|
||||
doc_1 = _build_document("1", {"old_name": "value", "other": "keep"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids")
|
||||
mock_get_documents.return_value = [doc_1, doc_2]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert metadata.name == "new_name"
|
||||
assert metadata.updated_by == "user-1"
|
||||
assert metadata.updated_at == fixed_now
|
||||
assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"}
|
||||
assert doc_2.doc_metadata == {"new_name": None}
|
||||
mock_get_documents.assert_called_once_with(["doc-1", "doc-2"])
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_return_none_when_metadata_does_not_exist(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_metadata_should_remove_metadata_and_related_document_fields(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
metadata = SimpleNamespace(id="metadata-1", name="obsolete")
|
||||
bindings = [SimpleNamespace(document_id="doc-1")]
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_metadata, query_bindings]
|
||||
|
||||
document = _build_document("1", {"obsolete": "legacy", "remaining": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document])
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "metadata-1")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert document.doc_metadata == {"remaining": "value"}
|
||||
mock_db.session.delete.assert_called_once_with(metadata)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_delete_metadata_should_return_none_when_metadata_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "missing-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_get_built_in_fields_should_return_all_expected_fields() -> None:
|
||||
# Arrange
|
||||
expected_names = {
|
||||
BuiltInField.document_name,
|
||||
BuiltInField.uploader,
|
||||
BuiltInField.upload_date,
|
||||
BuiltInField.last_update_date,
|
||||
BuiltInField.source,
|
||||
}
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_built_in_fields()
|
||||
|
||||
# Assert
|
||||
assert {item["name"] for item in result} == expected_names
|
||||
assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"]
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_return_immediately_when_already_enabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_populate_documents_and_enable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
doc_1 = _build_document("1", {"custom": "value"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[doc_1, doc_2],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is True
|
||||
assert doc_1.doc_metadata is not None
|
||||
assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1"
|
||||
assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert doc_2.doc_metadata is not None
|
||||
assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_return_immediately_when_already_disabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document(
|
||||
"1",
|
||||
{
|
||||
BuiltInField.document_name: "doc",
|
||||
BuiltInField.uploader: "user",
|
||||
BuiltInField.upload_date: 1.0,
|
||||
BuiltInField.last_update_date: 2.0,
|
||||
BuiltInField.source: MetadataDataSource.upload_file,
|
||||
"custom": "keep",
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[document],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is False
|
||||
assert document.doc_metadata == {"custom": "keep"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
document = _build_document("1", {"legacy": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
delete_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
delete_chain.delete.return_value = 1
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")],
|
||||
partial_update=False,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata == {"priority": "high"}
|
||||
delete_chain.delete.assert_called_once()
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document("1", {"existing": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")],
|
||||
partial_update=True,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata is not None
|
||||
assert document.doc_metadata["existing"] == "value"
|
||||
assert document.doc_metadata["new_key"] == "new_value"
|
||||
assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None)
|
||||
operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("dataset_id", "document_id", "expected_key"),
|
||||
[
|
||||
("dataset-1", None, "dataset_metadata_lock_dataset-1"),
|
||||
(None, "doc-1", "document_metadata_lock_doc-1"),
|
||||
],
|
||||
)
|
||||
def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked(
|
||||
dataset_id: str | None,
|
||||
document_id: str | None,
|
||||
expected_key: str,
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="document metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(
|
||||
id="dataset-1",
|
||||
built_in_field_enabled=True,
|
||||
doc_metadata=[
|
||||
{"id": "meta-1", "name": "priority", "type": "string"},
|
||||
{"id": "built-in", "name": "ignored", "type": "string"},
|
||||
{"id": "meta-2", "name": "score", "type": "number"},
|
||||
],
|
||||
)
|
||||
count_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
count_chain.count.side_effect = [3, 1]
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result["built_in_field_enabled"] is True
|
||||
assert result["doc_metadata"] == [
|
||||
{"id": "meta-1", "name": "priority", "type": "string", "count": 3},
|
||||
{"id": "meta-2", "name": "score", "type": "number", "count": 1},
|
||||
]
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result == {"doc_metadata": [], "built_in_field_enabled": False}
|
||||
mock_db.session.query.assert_not_called()
|
||||
@ -0,0 +1,808 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FieldModelSchema,
|
||||
FormType,
|
||||
ModelCredentialSchema,
|
||||
ProviderCredentialSchema,
|
||||
)
|
||||
from models.provider import LoadBalancingModelConfig
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
def _build_provider_credential_schema() -> ProviderCredentialSchema:
|
||||
return ProviderCredentialSchema(
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _build_model_credential_schema() -> ModelCredentialSchema:
|
||||
return ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_provider_configuration(
|
||||
*,
|
||||
custom_provider: bool = False,
|
||||
load_balancing_enabled: bool | None = None,
|
||||
model_schema: ModelCredentialSchema | None = None,
|
||||
provider_schema: ProviderCredentialSchema | None = None,
|
||||
) -> MagicMock:
|
||||
provider_configuration = MagicMock()
|
||||
provider_configuration.provider = SimpleNamespace(
|
||||
provider="openai",
|
||||
model_credential_schema=model_schema,
|
||||
provider_credential_schema=provider_schema,
|
||||
)
|
||||
provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider)
|
||||
provider_configuration.extract_secret_variables.return_value = ["api_key"]
|
||||
provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials
|
||||
provider_configuration.get_provider_model_setting.return_value = (
|
||||
None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled)
|
||||
)
|
||||
return provider_configuration
|
||||
|
||||
|
||||
def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig:
|
||||
return cast(LoadBalancingModelConfig, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(mocker: MockerFixture) -> ModelLoadBalancingService:
|
||||
# Arrange
|
||||
provider_manager = MagicMock()
|
||||
mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager)
|
||||
svc = ModelLoadBalancingService()
|
||||
svc.provider_manager = provider_manager
|
||||
return svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.model_load_balancing_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "expected_provider_method"),
|
||||
[
|
||||
("enable_model_load_balancing", "enable_model_load_balancing"),
|
||||
("disable_model_load_balancing", "disable_model_load_balancing"),
|
||||
],
|
||||
)
|
||||
def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists(
|
||||
method_name: str,
|
||||
expected_provider_method: str,
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
# Act
|
||||
getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
getattr(provider_configuration, expected_provider_method).assert_called_once_with(
|
||||
model="gpt-4o-mini", model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name",
|
||||
["enable_model_load_balancing", "disable_model_load_balancing"],
|
||||
)
|
||||
def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing(
|
||||
method_name: str,
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
|
||||
|
||||
|
||||
def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
|
||||
|
||||
|
||||
def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(
|
||||
custom_provider=True,
|
||||
load_balancing_enabled=True,
|
||||
provider_schema=_build_provider_credential_schema(),
|
||||
)
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
config = SimpleNamespace(
|
||||
id="cfg-1",
|
||||
name="primary",
|
||||
encrypted_config=json.dumps({"api_key": "encrypted-key"}),
|
||||
credential_id="cred-1",
|
||||
enabled=True,
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config]
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
|
||||
return_value=("rsa", "cipher"),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
|
||||
return_value="plain-key",
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
|
||||
return_value=(False, 0),
|
||||
)
|
||||
|
||||
# Act
|
||||
is_enabled, configs = service.get_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert is_enabled is True
|
||||
assert len(configs) == 2
|
||||
assert configs[0]["name"] == "__inherit__"
|
||||
assert configs[1]["name"] == "primary"
|
||||
assert configs[1]["credentials"] == {"api_key": "plain-key"}
|
||||
assert mock_db.session.add.call_count == 1
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
|
||||
|
||||
def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(
|
||||
custom_provider=True,
|
||||
load_balancing_enabled=None,
|
||||
provider_schema=_build_provider_credential_schema(),
|
||||
)
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
normal_config = SimpleNamespace(
|
||||
id="cfg-1",
|
||||
name="normal",
|
||||
encrypted_config=json.dumps({"api_key": "bad-encrypted"}),
|
||||
credential_id="cred-1",
|
||||
enabled=True,
|
||||
)
|
||||
inherit_config = SimpleNamespace(
|
||||
id="cfg-2",
|
||||
name="__inherit__",
|
||||
encrypted_config="not-json",
|
||||
credential_id=None,
|
||||
enabled=False,
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
|
||||
normal_config,
|
||||
inherit_config,
|
||||
]
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
|
||||
return_value=("rsa", "cipher"),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
|
||||
side_effect=ValueError("cannot decrypt"),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
|
||||
return_value=(True, 15),
|
||||
)
|
||||
|
||||
# Act
|
||||
is_enabled, configs = service.get_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
config_from="predefined-model",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert is_enabled is False
|
||||
assert configs[0]["name"] == "__inherit__"
|
||||
assert configs[0]["credentials"] == {}
|
||||
assert configs[1]["credentials"] == {"api_key": "bad-encrypted"}
|
||||
assert configs[1]["in_cooldown"] is True
|
||||
assert configs[1]["ttl"] == 15
|
||||
|
||||
|
||||
def test_get_load_balancing_config_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
|
||||
|
||||
def test_get_load_balancing_config_should_return_none_when_config_not_found(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: {
|
||||
"masked": credentials.get("api_key", "")
|
||||
}
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = config
|
||||
|
||||
# Act
|
||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
|
||||
# Assert
|
||||
assert result == {
|
||||
"id": "cfg-1",
|
||||
"name": "primary",
|
||||
"credentials": {"masked": ""},
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def test_init_inherit_config_should_create_and_persist_inherit_configuration(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
model_type = ModelType.LLM
|
||||
|
||||
# Act
|
||||
inherit_config = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", model_type)
|
||||
|
||||
# Assert
|
||||
assert inherit_config.tenant_id == "tenant-1"
|
||||
assert inherit_config.provider_name == "openai"
|
||||
assert inherit_config.model_name == "gpt-4o-mini"
|
||||
assert inherit_config.model_type == "text-generation"
|
||||
assert inherit_config.name == "__inherit__"
|
||||
mock_db.session.add.assert_called_once_with(inherit_config)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing configs"):
|
||||
service.update_load_balancing_configs( # type: ignore[arg-type]
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
cast(list[dict[str, object]], "invalid-configs"),
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config"):
|
||||
service.update_load_balancing_configs( # type: ignore[list-item]
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
cast(list[dict[str, object]], ["bad-item"]),
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"credential_id": "cred-1", "enabled": True}],
|
||||
"predefined-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config name"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"enabled": True}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config enabled"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "cfg-without-enabled"}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
current_config = SimpleNamespace(id="cfg-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [current_config]
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"id": "cfg-2", "name": "invalid", "enabled": True}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None)
|
||||
mock_db.session.scalars.return_value.all.return_value = [existing_config]
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "new-config", "enabled": True, "credentials": "bad"}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
existing_config_1 = SimpleNamespace(
|
||||
id="cfg-1",
|
||||
name="existing-one",
|
||||
enabled=True,
|
||||
encrypted_config=json.dumps({"api_key": "old"}),
|
||||
updated_at=None,
|
||||
)
|
||||
existing_config_2 = SimpleNamespace(
|
||||
id="cfg-2",
|
||||
name="existing-two",
|
||||
enabled=True,
|
||||
encrypted_config=None,
|
||||
updated_at=None,
|
||||
)
|
||||
mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2]
|
||||
mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"})
|
||||
mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache")
|
||||
|
||||
# Act
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[
|
||||
{"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}},
|
||||
{"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}},
|
||||
],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert existing_config_1.name == "updated-name"
|
||||
assert existing_config_1.enabled is False
|
||||
assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"}
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.delete.assert_called_once_with(existing_config_2)
|
||||
assert mock_db.session.commit.call_count >= 3
|
||||
mock_clear_cache.assert_any_call("tenant-1", "cfg-1")
|
||||
mock_clear_cache.assert_any_call("tenant-1", "cfg-2")
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config name"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "new", "enabled": True}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record
|
||||
|
||||
# Act
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"credential_id": "cred-1", "enabled": True}],
|
||||
"predefined-model",
|
||||
)
|
||||
|
||||
# Assert
|
||||
created_config = mock_db.session.add.call_args.args[0]
|
||||
assert created_config.name == "Main Credential"
|
||||
assert created_config.credential_id == "cred-1"
|
||||
assert created_config.credential_source_type == "provider"
|
||||
assert created_config.encrypted_config == '{"api_key":"enc"}'
|
||||
mock_db.session.commit.assert_called()
|
||||
|
||||
|
||||
def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
existing_config = SimpleNamespace(id="cfg-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = existing_config
|
||||
mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
|
||||
|
||||
# Act
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_validate.call_count == 2
|
||||
assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config
|
||||
assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None
|
||||
|
||||
|
||||
def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
load_balancing_model_config = _load_balancing_model_config(
|
||||
encrypted_config=json.dumps({"api_key": "old-encrypted-token"})
|
||||
)
|
||||
mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value")
|
||||
mock_encrypt = mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service._custom_credentials_validate(
|
||||
tenant_id="tenant-1",
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": HIDDEN_VALUE, "region": "us"},
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"api_key": "enc:old-plain-value", "region": "us"}
|
||||
mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value")
|
||||
|
||||
|
||||
def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema())
|
||||
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
|
||||
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
|
||||
mock_encrypt = mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service._custom_credentials_validate(
|
||||
tenant_id="tenant-1",
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "plain"},
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
validate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"api_key": "enc:validated"}
|
||||
mock_factory.model_credentials_validate.assert_called_once()
|
||||
mock_factory.provider_credentials_validate.assert_not_called()
|
||||
mock_encrypt.assert_called_once_with("tenant-1", "validated")
|
||||
|
||||
|
||||
def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
|
||||
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service._custom_credentials_validate(
|
||||
tenant_id="tenant-1",
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "plain"},
|
||||
validate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"api_key": "enc:provider-validated"}
|
||||
mock_factory.provider_credentials_validate.assert_called_once()
|
||||
mock_factory.model_credentials_validate.assert_not_called()
|
||||
|
||||
|
||||
def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
model_schema = _build_model_credential_schema()
|
||||
provider_schema = _build_provider_credential_schema()
|
||||
provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema)
|
||||
provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema)
|
||||
provider_configuration_without_schema = _build_provider_configuration()
|
||||
|
||||
# Act
|
||||
schema_from_model = service._get_credential_schema(provider_configuration_with_model)
|
||||
schema_from_provider = service._get_credential_schema(provider_configuration_with_provider)
|
||||
|
||||
# Assert
|
||||
assert schema_from_model is model_schema
|
||||
assert schema_from_provider is provider_schema
|
||||
with pytest.raises(ValueError, match="No credential schema found"):
|
||||
service._get_credential_schema(provider_configuration_without_schema)
|
||||
|
||||
|
||||
def test_clear_credentials_cache_should_delete_load_balancing_cache_entry(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_cls = mocker.patch(
|
||||
"services.model_load_balancing_service.ProviderCredentialsCache",
|
||||
return_value=mock_cache_instance,
|
||||
)
|
||||
|
||||
# Act
|
||||
service._clear_credentials_cache("tenant-1", "cfg-1")
|
||||
|
||||
# Assert
|
||||
mock_cache_cls.assert_called_once()
|
||||
assert mock_cache_cls.call_args.kwargs == {
|
||||
"tenant_id": "tenant-1",
|
||||
"identity_id": "cfg-1",
|
||||
"cache_type": mocker.ANY,
|
||||
}
|
||||
assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL"
|
||||
mock_cache_instance.delete.assert_called_once()
|
||||
224
api/tests/unit_tests/services/test_oauth_server_service.py
Normal file
224
api/tests/unit_tests/services/test_oauth_server_service.py
Normal file
@ -0,0 +1,224 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.oauth_server.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
"""Mock the OAuth server Session context manager."""
|
||||
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
mocker.patch("services.oauth_server.Session", return_value=session_cm)
|
||||
return session
|
||||
|
||||
|
||||
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_execute_result = MagicMock()
|
||||
expected_app = MagicMock()
|
||||
mock_execute_result.scalar_one_or_none.return_value = expected_app
|
||||
mock_session.execute.return_value = mock_execute_result
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.get_oauth_provider_app("client-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected_app
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_execute_result.scalar_one_or_none.assert_called_once()
|
||||
|
||||
|
||||
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
# Assert
|
||||
expected_code = str(deterministic_uuid)
|
||||
assert code == expected_code
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
|
||||
# Act
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
mock_redis_client.delete.assert_called_once_with(code_key)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
|
||||
# Act
|
||||
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh_token == "refresh-1"
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
|
||||
# Arrange
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=grant_type,
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
# Assert
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_load_user_when_token_exists(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
# Assert
|
||||
assert result is expected_user
|
||||
mock_load_user.assert_called_once_with("user-88")
|
||||
1249
api/tests/unit_tests/services/test_trigger_provider_service.py
Normal file
1249
api/tests/unit_tests/services/test_trigger_provider_service.py
Normal file
File diff suppressed because it is too large
Load Diff
259
api/tests/unit_tests/services/test_web_conversation_service.py
Normal file
259
api/tests/unit_tests/services/test_web_conversation_service.py
Normal file
@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import Account
|
||||
from models.model import App, EndUser
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_model() -> App:
|
||||
return cast(App, SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _end_user(**kwargs: Any) -> EndUser:
|
||||
return cast(EndUser, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_raise_error_when_user_is_none(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=None,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_user = _account(id="user-1")
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id="conv-9",
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] is None
|
||||
assert call_kwargs["exclude_ids"] is None
|
||||
assert call_kwargs["last_id"] == "conv-9"
|
||||
assert call_kwargs["sort_by"] == "-updated_at"
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-1"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {}))
|
||||
session.scalars.return_value.all.return_value = ["conv-1", "conv-2"]
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] == ["conv-1", "conv-2"]
|
||||
assert call_kwargs["exclude_ids"] is None
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-1"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
session.scalars.return_value.all.return_value = ["conv-3"]
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] is None
|
||||
assert call_kwargs["exclude_ids"] == ["conv-3"]
|
||||
|
||||
|
||||
def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-1", None)
|
||||
|
||||
# Assert
|
||||
mock_db.session.add.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_pin_should_return_early_when_conversation_is_already_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-1"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-1", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_get_conversation.assert_not_called()
|
||||
mock_db.session.add.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_pin_should_create_pinned_conversation_when_not_already_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-2"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_conversation = SimpleNamespace(id="conv-2")
|
||||
mock_get_conversation = mocker.patch(
|
||||
"services.web_conversation_service.ConversationService.get_conversation",
|
||||
return_value=mock_conversation,
|
||||
)
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-2", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user)
|
||||
added_obj = mock_db.session.add.call_args.args[0]
|
||||
assert added_obj.app_id == "app-1"
|
||||
assert added_obj.conversation_id == "conv-2"
|
||||
assert added_obj.created_by_role == "account"
|
||||
assert added_obj.created_by == "account-2"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-1", None)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_unpin_should_return_early_when_conversation_is_not_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-3"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-7", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_unpin_should_delete_pinned_conversation_when_exists(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-4"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
pinned_obj = SimpleNamespace(id="pin-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-8", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_called_once_with(pinned_obj)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
379
api/tests/unit_tests/services/test_webapp_auth_service.py
Normal file
379
api/tests/unit_tests/services/test_webapp_auth_service.py
Normal file
@ -0,0 +1,379 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from models import Account, AccountStatus
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
|
||||
ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback"
|
||||
TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token"
|
||||
TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data"
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.webapp_auth_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
|
||||
def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt")
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AccountLoginError, match="Account is banned"):
|
||||
WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("password_value", [None, "hash"])
|
||||
def test_authenticate_should_raise_password_error_when_password_is_invalid(
|
||||
password_value: str | None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt")
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
mocker.patch("services.webapp_auth_service.compare_password", return_value=False)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AccountPasswordError, match="Invalid email or password"):
|
||||
WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
|
||||
def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt")
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
mocker.patch("services.webapp_auth_service.compare_password", return_value=True)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
# Assert
|
||||
assert result is account
|
||||
|
||||
|
||||
def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = _account(id="a1", email="u@example.com")
|
||||
mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token")
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.login(account)
|
||||
|
||||
# Assert
|
||||
assert result == "jwt-token"
|
||||
mock_get_token.assert_called_once_with(account=account)
|
||||
|
||||
|
||||
def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_user_through_email("missing@example.com")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.BANNED)
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(Unauthorized, match="Account is banned"):
|
||||
WebAppAuthService.get_user_through_email("user@example.com")
|
||||
|
||||
|
||||
def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.ACTIVE)
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_user_through_email("user@example.com")
|
||||
|
||||
# Assert
|
||||
assert result is account
|
||||
|
||||
|
||||
def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Email must be provided"):
|
||||
WebAppAuthService.send_email_code_login_email(account=None, email=None)
|
||||
|
||||
|
||||
def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
account = _account(email="user@example.com")
|
||||
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6])
|
||||
mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1")
|
||||
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
|
||||
|
||||
# Assert
|
||||
assert result == "token-1"
|
||||
mock_generate_token.assert_called_once()
|
||||
assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"}
|
||||
mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456")
|
||||
|
||||
|
||||
def test_send_email_code_login_email_should_send_mail_for_email_without_account(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0])
|
||||
mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2")
|
||||
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans")
|
||||
|
||||
# Assert
|
||||
assert result == "token-2"
|
||||
mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000")
|
||||
|
||||
|
||||
def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"})
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_email_code_login_data("token-abc")
|
||||
|
||||
# Assert
|
||||
assert result == {"code": "123"}
|
||||
mock_get_data.assert_called_once_with("token-abc", "email_code_login")
|
||||
|
||||
|
||||
def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token")
|
||||
|
||||
# Act
|
||||
WebAppAuthService.revoke_email_code_login_token("token-xyz")
|
||||
|
||||
# Assert
|
||||
mock_revoke.assert_called_once_with("token-xyz", "email_code_login")
|
||||
|
||||
|
||||
def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(NotFound, match="Site not found"):
|
||||
WebAppAuthService.create_end_user("app-code", "user@example.com")
|
||||
|
||||
|
||||
def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
site = SimpleNamespace(app_id="app-1")
|
||||
app_query = MagicMock()
|
||||
app_query.where.return_value.first.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None]
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(NotFound, match="App not found"):
|
||||
WebAppAuthService.create_end_user("app-code", "user@example.com")
|
||||
|
||||
|
||||
def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
site = SimpleNamespace(app_id="app-1")
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model]
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.create_end_user("app-code", "user@example.com")
|
||||
|
||||
# Assert
|
||||
assert result.tenant_id == "tenant-1"
|
||||
assert result.app_id == "app-1"
|
||||
assert result.session_id == "user@example.com"
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = _account(id="a1", email="user@example.com")
|
||||
mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60)
|
||||
mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1")
|
||||
|
||||
# Act
|
||||
token = WebAppAuthService._get_account_jwt_token(account)
|
||||
|
||||
# Assert
|
||||
assert token == "jwt-1"
|
||||
payload = mock_issue.call_args.args[0]
|
||||
assert payload["user_id"] == "a1"
|
||||
assert payload["session_id"] == "user@example.com"
|
||||
assert payload["token_source"] == "webapp_login_token"
|
||||
assert payload["auth_type"] == "internal"
|
||||
assert payload["exp"] > int(datetime.now(UTC).timestamp())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "expected"),
|
||||
[
|
||||
("private", True),
|
||||
("private_all", True),
|
||||
("public", False),
|
||||
],
|
||||
)
|
||||
def test_is_app_require_permission_check_should_use_access_mode_when_provided(
|
||||
access_mode: str,
|
||||
expected: bool,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode)
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Either app_code or app_id must be provided"):
|
||||
WebAppAuthService.is_app_require_permission_check()
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="App ID could not be determined"):
|
||||
WebAppAuthService.is_app_require_permission_check(app_code="app-code")
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
|
||||
mocker.patch(
|
||||
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=SimpleNamespace(access_mode="private"),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.is_app_require_permission_check(app_code="app-code")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=SimpleNamespace(access_mode="public"),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.is_app_require_permission_check(app_id="app-1")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "expected"),
|
||||
[
|
||||
("public", WebAppAuthType.PUBLIC),
|
||||
("private", WebAppAuthType.INTERNAL),
|
||||
("private_all", WebAppAuthType.INTERNAL),
|
||||
("sso_verified", WebAppAuthType.EXTERNAL),
|
||||
],
|
||||
)
|
||||
def test_get_app_auth_type_should_map_access_modes_correctly(
|
||||
access_mode: str,
|
||||
expected: WebAppAuthType,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = WebAppAuthService.get_app_auth_type(access_mode=access_mode)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
|
||||
mocker.patch(
|
||||
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=SimpleNamespace(access_mode="private_all"),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_app_auth_type(app_code="app-code")
|
||||
|
||||
# Assert
|
||||
assert result == WebAppAuthType.INTERNAL
|
||||
|
||||
|
||||
def test_get_app_auth_type_should_raise_when_no_input_provided() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"):
|
||||
WebAppAuthService.get_app_auth_type()
|
||||
|
||||
|
||||
def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Could not determine app authentication type"):
|
||||
WebAppAuthService.get_app_auth_type(access_mode="unknown")
|
||||
300
api/tests/unit_tests/services/test_workflow_app_service.py
Normal file
300
api/tests/unit_tests/services/test_workflow_app_service.py
Normal file
@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from models import App, WorkflowAppLog
|
||||
from models.enums import AppTriggerType, CreatorUserRole
|
||||
from services.workflow_app_service import LogView, WorkflowAppService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service() -> WorkflowAppService:
|
||||
# Arrange
|
||||
return WorkflowAppService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_model() -> App:
|
||||
# Arrange
|
||||
return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))
|
||||
|
||||
|
||||
def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog:
|
||||
return cast(WorkflowAppLog, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None:
|
||||
# Arrange
|
||||
log = _workflow_app_log(id="log-1", status="succeeded")
|
||||
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})
|
||||
|
||||
# Act
|
||||
details = view.details
|
||||
proxied_status = view.status
|
||||
|
||||
# Assert
|
||||
assert details == {"trigger_metadata": {"type": "plugin"}}
|
||||
assert proxied_status == "succeeded"
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
log_1 = SimpleNamespace(id="log-1")
|
||||
log_2 = SimpleNamespace(id="log-2")
|
||||
session.scalar.return_value = 3
|
||||
session.scalars.return_value.all.return_value = [log_1, log_2]
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=1,
|
||||
limit=2,
|
||||
detail=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["page"] == 1
|
||||
assert result["limit"] == 2
|
||||
assert result["total"] == 3
|
||||
assert result["has_more"] is True
|
||||
assert len(result["data"]) == 2
|
||||
assert isinstance(result["data"][0], LogView)
|
||||
assert result["data"][0].details is None
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [1]
|
||||
log_1 = SimpleNamespace(id="log-1")
|
||||
session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')]
|
||||
mock_handle = mocker.patch.object(
|
||||
service,
|
||||
"handle_trigger_metadata",
|
||||
return_value={"type": "trigger_plugin", "icon": "url"},
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
keyword="run-1",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_at_before=None,
|
||||
created_at_after=None,
|
||||
page=1,
|
||||
limit=20,
|
||||
detail=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 1
|
||||
assert len(result["data"]) == 1
|
||||
assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}}
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Account not found: account@example.com"):
|
||||
service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
created_by_account="account@example.com",
|
||||
)
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0]
|
||||
session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
created_by_account="account@example.com",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 0
|
||||
assert result["data"] == []
|
||||
|
||||
|
||||
def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
log_account = SimpleNamespace(
|
||||
id="log-1",
|
||||
created_by="acc-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
workflow_run_summary={"run": "1"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-01",
|
||||
)
|
||||
log_end_user = SimpleNamespace(
|
||||
id="log-2",
|
||||
created_by="end-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
workflow_run_summary={"run": "2"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-02",
|
||||
)
|
||||
log_unknown = SimpleNamespace(
|
||||
id="log-3",
|
||||
created_by="other",
|
||||
created_by_role="system",
|
||||
workflow_run_summary={"run": "3"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-03",
|
||||
)
|
||||
session.scalar.return_value = 3
|
||||
session.scalars.side_effect = [
|
||||
SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]),
|
||||
SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]),
|
||||
SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=1,
|
||||
limit=20,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 3
|
||||
assert len(result["data"]) == 3
|
||||
assert result["data"][0]["created_by_account"].id == "acc-1"
|
||||
assert result["data"][1]["created_by_end_user"].id == "end-1"
|
||||
assert result["data"][2]["created_by_account"] is None
|
||||
assert result["data"][2]["created_by_end_user"] is None
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing(
|
||||
service: WorkflowAppService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", None)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin(
|
||||
service: WorkflowAppService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
meta = {
|
||||
"type": AppTriggerType.TRIGGER_PLUGIN.value,
|
||||
"icon_filename": "light.png",
|
||||
"icon_dark_filename": "dark.png",
|
||||
}
|
||||
mock_icon = mocker.patch(
|
||||
"services.workflow_app_service.PluginService.get_plugin_icon_url",
|
||||
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
|
||||
|
||||
# Assert
|
||||
assert result["icon"] == "https://cdn/light.png"
|
||||
assert result["icon_dark"] == "https://cdn/dark.png"
|
||||
assert mock_icon.call_count == 2
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup(
|
||||
service: WorkflowAppService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
|
||||
mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url")
|
||||
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
|
||||
|
||||
# Assert
|
||||
assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
|
||||
mock_icon.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
('{"k":"v"}', {"k": "v"}),
|
||||
("not-json", None),
|
||||
({"raw": True}, {"raw": True}),
|
||||
],
|
||||
)
|
||||
def test_safe_json_loads_should_handle_various_inputs(
|
||||
value: object,
|
||||
expected: object,
|
||||
service: WorkflowAppService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = service._safe_json_loads(value)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
short_result = service._safe_parse_uuid("short")
|
||||
invalid_result = service._safe_parse_uuid("x" * 40)
|
||||
|
||||
# Assert
|
||||
assert short_result is None
|
||||
assert invalid_result is None
|
||||
|
||||
|
||||
def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None:
|
||||
# Arrange
|
||||
raw_uuid = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
result = service._safe_parse_uuid(raw_uuid)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert str(result) == raw_uuid
|
||||
File diff suppressed because it is too large
Load Diff
576
api/tests/unit_tests/services/test_workspace_service.py
Normal file
576
api/tests/unit_tests/services/test_workspace_service.py
Normal file
@ -0,0 +1,576 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from models.account import Tenant
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants used throughout the tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TENANT_ID = "tenant-abc"
|
||||
ACCOUNT_ID = "account-xyz"
|
||||
FILES_BASE_URL = "https://files.example.com"
|
||||
|
||||
DB_PATH = "services.workspace_service.db"
|
||||
FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features"
|
||||
TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles"
|
||||
DIFY_CONFIG_PATH = "services.workspace_service.dify_config"
|
||||
CURRENT_USER_PATH = "services.workspace_service.current_user"
|
||||
CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tenant(
|
||||
tenant_id: str = TENANT_ID,
|
||||
name: str = "My Workspace",
|
||||
plan: str = "sandbox",
|
||||
status: str = "active",
|
||||
custom_config: dict | None = None,
|
||||
) -> Tenant:
|
||||
"""Create a minimal Tenant-like namespace."""
|
||||
return cast(
|
||||
Tenant,
|
||||
SimpleNamespace(
|
||||
id=tenant_id,
|
||||
name=name,
|
||||
plan=plan,
|
||||
status=status,
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
custom_config_dict=custom_config or {},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_feature(
|
||||
can_replace_logo: bool = False,
|
||||
next_credit_reset_date: str | None = None,
|
||||
billing_plan: str = "sandbox",
|
||||
) -> MagicMock:
|
||||
"""Create a feature namespace matching what FeatureService.get_features returns."""
|
||||
feature = MagicMock()
|
||||
feature.can_replace_logo = can_replace_logo
|
||||
feature.next_credit_reset_date = next_credit_reset_date
|
||||
feature.billing.subscription.plan = billing_plan
|
||||
return feature
|
||||
|
||||
|
||||
def _make_pool(quota_limit: int, quota_used: int) -> MagicMock:
|
||||
pool = MagicMock()
|
||||
pool.quota_limit = quota_limit
|
||||
pool.quota_used = quota_used
|
||||
return pool
|
||||
|
||||
|
||||
def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(role=role)
|
||||
|
||||
|
||||
def _tenant_info(result: object) -> dict[str, Any] | None:
|
||||
return cast(dict[str, Any] | None, result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user() -> SimpleNamespace:
|
||||
"""Return a lightweight current_user stand-in."""
|
||||
return SimpleNamespace(id=ACCOUNT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""
|
||||
Patch the common external boundaries used by WorkspaceService.get_tenant_info.
|
||||
|
||||
Returns a dict of named mocks so individual tests can customise them.
|
||||
"""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature())
|
||||
mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "SELF_HOSTED"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"has_roles": mock_has_roles,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. None Tenant Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None:
|
||||
"""get_tenant_info should short-circuit and return None for a falsy tenant."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = None
|
||||
|
||||
# Act
|
||||
result = WorkspaceService.get_tenant_info(cast(Tenant, tenant))
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None:
|
||||
"""get_tenant_info treats any falsy value as absent (e.g. empty string, 0)."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange / Act / Assert
|
||||
assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Basic Tenant Info — happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_base_fields(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""get_tenant_info should always return the six base scalar fields."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["id"] == TENANT_ID
|
||||
assert result["name"] == "My Workspace"
|
||||
assert result["plan"] == "sandbox"
|
||||
assert result["status"] == "active"
|
||||
assert result["created_at"] == "2024-01-01T00:00:00Z"
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_populate_role_from_tenant_account_join(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The 'role' field should be taken from TenantAccountJoin, not the default."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["role"] == "admin"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The service asserts that TenantAccountJoin exists.
|
||||
Missing join should raise AssertionError.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = None
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Logo Customisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block should appear for OWNER/ADMIN when can_replace_logo is True."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(
|
||||
custom_config={
|
||||
"replace_webapp_logo": True,
|
||||
"remove_webapp_brand": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is True
|
||||
expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url
|
||||
|
||||
|
||||
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config should be absent when can_replace_logo is False."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block is gated on OWNER or ADMIN role."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = False # regular member
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_files_url_for_logo_url(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The logo URL should use dify_config.FILES_URL as the base."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
custom_base = "https://cdn.mycompany.io"
|
||||
basic_mocks["config"].FILES_URL = custom_base
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={"replace_webapp_logo": True})
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Cloud-Edition Credit Features
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""Patches for CLOUD edition tests, billing plan = professional by default."""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(
|
||||
FEATURE_SERVICE_PATH,
|
||||
return_value=_make_feature(
|
||||
can_replace_logo=False,
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX,
|
||||
),
|
||||
)
|
||||
mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "CLOUD"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date should be present in CLOUD edition."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
CREDIT_POOL_SERVICE_PATH,
|
||||
side_effect=[None, None], # both paid and trial pools absent
|
||||
)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["next_credit_reset_date"] == "2025-02-01"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""trial_credits/trial_credits_used come from the paid pool when conditions are met."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=200)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 1000
|
||||
assert result["trial_credits_used"] == 200
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""quota_limit == -1 means unlimited; service should still use the paid pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=-1, quota_used=999)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == -1
|
||||
assert result["trial_credits_used"] == 999
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid pool is exhausted (used >= limit), switch to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full
|
||||
trial_pool = _make_pool(quota_limit=100, quota_used=10)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 100
|
||||
assert result["trial_credits_used"] == 10
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid_pool is None, fall back to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
trial_pool = _make_pool(quota_limit=50, quota_used=5)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 50
|
||||
assert result["trial_credits_used"] == 5
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
When the subscription plan IS SANDBOX, the paid pool branch is skipped
|
||||
entirely and we fall back to the trial pool.
|
||||
"""
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange — override billing plan to SANDBOX
|
||||
cloud_mocks["get_features"].return_value = _make_feature(
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CloudPlan.SANDBOX,
|
||||
)
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=0)
|
||||
trial_pool = _make_pool(quota_limit=200, quota_used=20)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 200
|
||||
assert result["trial_credits_used"] == 20
|
||||
|
||||
|
||||
def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When both paid and trial pools are absent, trial_credits should not be set."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Self-hosted / Non-Cloud Edition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange (basic_mocks already sets EDITION = "SELF_HOSTED")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "next_credit_reset_date" not in result
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. DB query integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The DB query for TenantAccountJoin must be scoped to the correct
|
||||
tenant_id and current_user.id.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant(tenant_id="my-special-tenant")
|
||||
mock_current_user = mocker.patch(CURRENT_USER_PATH)
|
||||
mock_current_user.id = "special-user-id"
|
||||
|
||||
# Act
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert — db.session.query was invoked (at least once)
|
||||
basic_mocks["db_session"].query.assert_called()
|
||||
@ -0,0 +1,643 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.tools.api_tools_manage_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
|
||||
return SimpleNamespace(operation_id=operation_id)
|
||||
|
||||
|
||||
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.parser_api_schema("valid-schema")
|
||||
|
||||
# Assert
|
||||
assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value
|
||||
assert len(result["credentials_schema"]) == 3
|
||||
assert "warning" in result
|
||||
|
||||
|
||||
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("bad schema"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
|
||||
ApiToolManageService.parser_api_schema("invalid")
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=expected,
|
||||
)
|
||||
extra_info: dict[str, str] = {}
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=ValueError("parse failed"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: parse failed"):
|
||||
ApiToolManageService.convert_schema_to_tool_bundles("schema")
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="provider provider-a already exists"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name=" provider-a ",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
many_tools = [_tool_bundle(str(i)) for i in range(101)]
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=(many_tools, ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="the number of apis should be less than 100"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=mock_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.encrypt.return_value = {"auth_type": "none"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_controller.load_bundled_tools.assert_called_once()
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=200, text="schema-content"),
|
||||
)
|
||||
mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
|
||||
# Assert
|
||||
assert result == {"schema": "schema-content"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [400, 404, 500])
|
||||
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
|
||||
status_code: int,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
|
||||
)
|
||||
mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
|
||||
ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
return_value=controller,
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == [{"name": "tool-a"}, {"name": "tool-b"}]
|
||||
assert mock_convert.call_count == 2
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="api provider provider-a does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(credentials={}, name="old")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(
|
||||
credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
|
||||
name="old",
|
||||
icon="",
|
||||
schema="",
|
||||
description="",
|
||||
schema_type_str="",
|
||||
tools_str="",
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
credentials_str="",
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=controller,
|
||||
)
|
||||
cache = MagicMock()
|
||||
encrypter = MagicMock()
|
||||
encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
|
||||
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
|
||||
encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(encrypter, cache),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-new",
|
||||
original_provider="provider-old",
|
||||
icon={"emoji": "E"},
|
||||
credentials={"auth_type": "none", "api_key_value": "***"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
assert provider.name == "provider-new"
|
||||
assert provider.privacy_policy == "privacy"
|
||||
assert provider.credentials_str != ""
|
||||
cache.delete.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
provider = object()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.delete.assert_called_once_with(provider)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = {"provider": "value"}
|
||||
mock_get = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
|
||||
return_value=expected,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
|
||||
# Arrange
|
||||
schema_type = "bad-schema-type"
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema type"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=schema_type, # type: ignore[arg-type]
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("invalid"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid tool name tool-b"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-b",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.side_effect = ValueError("validation failed")
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "validation failed"}
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.return_value = {"ok": True}
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={"x": "1"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": {"ok": True}}
|
||||
|
||||
|
||||
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_one = SimpleNamespace(name="p1")
|
||||
provider_two = SimpleNamespace(name="p2")
|
||||
mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two]
|
||||
|
||||
controller_one = MagicMock()
|
||||
controller_one.get_tools.return_value = ["tool-a"]
|
||||
controller_two = MagicMock()
|
||||
controller_two.get_tools.return_value = ["tool-b", "tool-c"]
|
||||
|
||||
user_provider_one = SimpleNamespace(labels=[], tools=[])
|
||||
user_provider_two = SimpleNamespace(labels=[], tools=[])
|
||||
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
side_effect=[controller_one, controller_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
|
||||
side_effect=[user_provider_one, user_provider_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tools("tenant-1")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert user_provider_one.tools == [{"name": "tool-a"}]
|
||||
assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
|
||||
assert mock_convert.call_count == 3
|
||||
1045
api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py
Normal file
1045
api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user