mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge commit '9c339239' into sandboxed-agent-rebase
Made-with: Cursor # Conflicts: # api/README.md # api/controllers/console/app/workflow_draft_variable.py # api/core/agent/cot_agent_runner.py # api/core/agent/fc_agent_runner.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/plugin/backwards_invocation/model.py # api/core/prompt/advanced_prompt_transform.py # api/core/workflow/nodes/base/node.py # api/core/workflow/nodes/llm/llm_utils.py # api/core/workflow/nodes/llm/node.py # api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py # api/core/workflow/nodes/question_classifier/question_classifier_node.py # api/core/workflow/runtime/graph_runtime_state.py # api/extensions/storage/base_storage.py # api/factories/variable_factory.py # api/pyproject.toml # api/services/variable_truncator.py # api/uv.lock # web/app/account/oauth/authorize/page.tsx # web/app/components/app/configuration/config-var/config-modal/field.tsx # web/app/components/base/alert.tsx # web/app/components/base/chat/chat/answer/human-input-content/executed-action.tsx # web/app/components/base/chat/chat/answer/more.tsx # web/app/components/base/chat/chat/answer/operation.tsx # web/app/components/base/chat/chat/answer/workflow-process.tsx # web/app/components/base/chat/chat/citation/index.tsx # web/app/components/base/chat/chat/citation/popup.tsx # web/app/components/base/chat/chat/citation/progress-tooltip.tsx # web/app/components/base/chat/chat/citation/tooltip.tsx # web/app/components/base/chat/chat/question.tsx # web/app/components/base/chat/embedded-chatbot/inputs-form/index.tsx # web/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown.tsx # web/app/components/base/markdown-blocks/form.tsx # web/app/components/base/prompt-editor/plugins/hitl-input-block/component-ui.tsx # web/app/components/base/tag-management/panel.tsx # web/app/components/base/tag-management/trigger.tsx # web/app/components/header/account-setting/index.tsx # web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx # web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx # web/app/signin/utils/post-login-redirect.ts # web/eslint-suppressions.json # web/package.json # web/pnpm-lock.yaml
This commit is contained in:
@ -19,7 +19,7 @@ class TestApiKeyAuthFactory:
|
||||
)
|
||||
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
|
||||
"""Test getting auth factory for all valid providers"""
|
||||
with patch(auth_class_path) as mock_auth:
|
||||
with patch(auth_class_path, autospec=True) as mock_auth:
|
||||
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
|
||||
assert auth_class == mock_auth
|
||||
|
||||
@ -46,7 +46,7 @@ class TestApiKeyAuthFactory:
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
|
||||
def test_validate_credentials_delegates_to_auth_instance(
|
||||
self, mock_get_factory, credentials_return_value, expected_result
|
||||
):
|
||||
@ -65,7 +65,7 @@ class TestApiKeyAuthFactory:
|
||||
assert result is expected_result
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
|
||||
def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
|
||||
"""Test that exceptions from auth instance are propagated"""
|
||||
# Arrange
|
||||
|
||||
@ -65,7 +65,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -96,7 +96,7 @@ class TestFirecrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@ -118,7 +118,7 @@ class TestFirecrawlAuth:
|
||||
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@ -145,7 +145,7 @@ class TestFirecrawlAuth:
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_post.side_effect = exception_type(exception_message)
|
||||
@ -167,7 +167,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation and normalized"""
|
||||
mock_response = MagicMock()
|
||||
@ -185,7 +185,7 @@ class TestFirecrawlAuth:
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
@ -35,7 +35,7 @@ class TestJinaAuth:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -53,7 +53,7 @@ class TestJinaAuth:
|
||||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
@ -68,7 +68,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_http_409_error(self, mock_post):
|
||||
"""Test handling of 409 Conflict error"""
|
||||
mock_response = MagicMock()
|
||||
@ -83,7 +83,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
@ -98,7 +98,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
@ -114,7 +114,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_unexpected_error_without_text(self, mock_post):
|
||||
"""Test handling of unexpected errors without text response"""
|
||||
mock_response = MagicMock()
|
||||
@ -130,7 +130,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
|
||||
@ -64,7 +64,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@ -132,7 +132,7 @@ class TestWatercrawlAuth:
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_get.side_effect = exception_type(exception_message)
|
||||
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_get):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
|
||||
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
|
||||
"""Test that urljoin is used correctly for URL construction with various base URLs"""
|
||||
mock_response = MagicMock()
|
||||
@ -193,7 +193,7 @@ class TestWatercrawlAuth:
|
||||
# Verify the correct URL was called
|
||||
assert mock_get.call_args[0][0] == expected_url
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
@ -1,932 +0,0 @@
|
||||
"""
|
||||
Comprehensive unit tests for DatasetCollectionBindingService.
|
||||
|
||||
This module contains extensive unit tests for the DatasetCollectionBindingService class,
|
||||
which handles dataset collection binding operations for vector database collections.
|
||||
|
||||
The DatasetCollectionBindingService provides methods for:
|
||||
- Retrieving or creating dataset collection bindings by provider, model, and type
|
||||
- Retrieving specific collection bindings by ID and type
|
||||
- Managing collection bindings for different collection types (dataset, etc.)
|
||||
|
||||
Collection bindings are used to map embedding models (provider + model name) to
|
||||
specific vector database collections, allowing datasets to share collections when
|
||||
they use the same embedding model configuration.
|
||||
|
||||
This test suite ensures:
|
||||
- Correct retrieval of existing bindings
|
||||
- Proper creation of new bindings when they don't exist
|
||||
- Accurate filtering by provider, model, and collection type
|
||||
- Proper error handling for missing bindings
|
||||
- Database transaction handling (add, commit)
|
||||
- Collection name generation using Dataset.gen_collection_name_by_id
|
||||
|
||||
================================================================================
|
||||
ARCHITECTURE OVERVIEW
|
||||
================================================================================
|
||||
|
||||
The DatasetCollectionBindingService is a critical component in the Dify platform's
|
||||
vector database management system. It serves as an abstraction layer between the
|
||||
application logic and the underlying vector database collections.
|
||||
|
||||
Key Concepts:
|
||||
1. Collection Binding: A mapping between an embedding model configuration
|
||||
(provider + model name) and a vector database collection name. This allows
|
||||
multiple datasets to share the same collection when they use identical
|
||||
embedding models, improving resource efficiency.
|
||||
|
||||
2. Collection Type: Different types of collections can exist (e.g., "dataset",
|
||||
"custom_type"). This allows for separation of collections based on their
|
||||
intended use case or data structure.
|
||||
|
||||
3. Provider and Model: The combination of provider_name (e.g., "openai",
|
||||
"cohere", "huggingface") and model_name (e.g., "text-embedding-ada-002")
|
||||
uniquely identifies an embedding model configuration.
|
||||
|
||||
4. Collection Name Generation: When a new binding is created, a unique collection
|
||||
name is generated using Dataset.gen_collection_name_by_id() with a UUID.
|
||||
This ensures each binding has a unique collection identifier.
|
||||
|
||||
================================================================================
|
||||
TESTING STRATEGY
|
||||
================================================================================
|
||||
|
||||
This test suite follows a comprehensive testing strategy that covers:
|
||||
|
||||
1. Happy Path Scenarios:
|
||||
- Successful retrieval of existing bindings
|
||||
- Successful creation of new bindings
|
||||
- Proper handling of default parameters
|
||||
|
||||
2. Edge Cases:
|
||||
- Different collection types
|
||||
- Various provider/model combinations
|
||||
- Default vs explicit parameter usage
|
||||
|
||||
3. Error Handling:
|
||||
- Missing bindings (for get_by_id_and_type)
|
||||
- Database query failures
|
||||
- Invalid parameter combinations
|
||||
|
||||
4. Database Interaction:
|
||||
- Query construction and execution
|
||||
- Transaction management (add, commit)
|
||||
- Query chaining (where, order_by, first)
|
||||
|
||||
5. Mocking Strategy:
|
||||
- Database session mocking
|
||||
- Query builder chain mocking
|
||||
- UUID generation mocking
|
||||
- Collection name generation mocking
|
||||
|
||||
================================================================================
|
||||
"""
|
||||
|
||||
"""
|
||||
Import statements for the test module.
|
||||
|
||||
This section imports all necessary dependencies for testing the
|
||||
DatasetCollectionBindingService, including:
|
||||
- unittest.mock for creating mock objects
|
||||
- pytest for test framework functionality
|
||||
- uuid for UUID generation (used in collection name generation)
|
||||
- Models and services from the application codebase
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
# ============================================================================
|
||||
# Test Data Factory
|
||||
# ============================================================================
|
||||
# The Test Data Factory pattern is used here to centralize the creation of
|
||||
# test objects and mock instances. This approach provides several benefits:
|
||||
#
|
||||
# 1. Consistency: All test objects are created using the same factory methods,
|
||||
# ensuring consistent structure across all tests.
|
||||
#
|
||||
# 2. Maintainability: If the structure of DatasetCollectionBinding or Dataset
|
||||
# changes, we only need to update the factory methods rather than every
|
||||
# individual test.
|
||||
#
|
||||
# 3. Reusability: Factory methods can be reused across multiple test classes,
|
||||
# reducing code duplication.
|
||||
#
|
||||
# 4. Readability: Tests become more readable when they use descriptive factory
|
||||
# method calls instead of complex object construction logic.
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class DatasetCollectionBindingTestDataFactory:
|
||||
"""
|
||||
Factory class for creating test data and mock objects for dataset collection binding tests.
|
||||
|
||||
This factory provides static methods to create mock objects for:
|
||||
- DatasetCollectionBinding instances
|
||||
- Database query results
|
||||
- Collection name generation results
|
||||
|
||||
The factory methods help maintain consistency across tests and reduce
|
||||
code duplication when setting up test scenarios.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_collection_binding_mock(
|
||||
binding_id: str = "binding-123",
|
||||
provider_name: str = "openai",
|
||||
model_name: str = "text-embedding-ada-002",
|
||||
collection_name: str = "collection-abc",
|
||||
collection_type: str = "dataset",
|
||||
created_at=None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock DatasetCollectionBinding with specified attributes.
|
||||
|
||||
Args:
|
||||
binding_id: Unique identifier for the binding
|
||||
provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
|
||||
model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
|
||||
collection_name: Name of the vector database collection
|
||||
collection_type: Type of collection (default: "dataset")
|
||||
created_at: Optional datetime for creation timestamp
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock object configured as a DatasetCollectionBinding instance
|
||||
"""
|
||||
binding = Mock(spec=DatasetCollectionBinding)
|
||||
binding.id = binding_id
|
||||
binding.provider_name = provider_name
|
||||
binding.model_name = model_name
|
||||
binding.collection_name = collection_name
|
||||
binding.type = collection_type
|
||||
binding.created_at = created_at
|
||||
for key, value in kwargs.items():
|
||||
setattr(binding, key, value)
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock Dataset for testing collection name generation.
|
||||
|
||||
Args:
|
||||
dataset_id: Unique identifier for the dataset
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock object configured as a Dataset instance
|
||||
"""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for get_dataset_collection_binding
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDatasetCollectionBindingServiceGetBinding:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding method.
|
||||
|
||||
This test class covers the main collection binding retrieval/creation functionality,
|
||||
including various provider/model combinations, collection types, and edge cases.
|
||||
|
||||
The get_dataset_collection_binding method:
|
||||
1. Queries for existing binding by provider_name, model_name, and collection_type
|
||||
2. Orders results by created_at (ascending) and takes the first match
|
||||
3. If no binding exists, creates a new one with:
|
||||
- The provided provider_name and model_name
|
||||
- A generated collection_name using Dataset.gen_collection_name_by_id
|
||||
- The provided collection_type
|
||||
4. Adds the new binding to the database session and commits
|
||||
5. Returns the binding (either existing or newly created)
|
||||
|
||||
Test scenarios include:
|
||||
- Retrieving existing bindings
|
||||
- Creating new bindings when none exist
|
||||
- Different collection types
|
||||
- Database transaction handling
|
||||
- Collection name generation
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""
|
||||
Mock database session for testing database operations.
|
||||
|
||||
Provides a mocked database session that can be used to verify:
|
||||
- Query construction and execution
|
||||
- Add operations for new bindings
|
||||
- Commit operations for transaction completion
|
||||
|
||||
The mock is configured to return a query builder that supports
|
||||
chaining operations like .where(), .order_by(), and .first().
|
||||
"""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_get_dataset_collection_binding_existing_binding_success(self, mock_db_session):
|
||||
"""
|
||||
Test successful retrieval of an existing collection binding.
|
||||
|
||||
Verifies that when a binding already exists in the database for the given
|
||||
provider, model, and collection type, the method returns the existing binding
|
||||
without creating a new one.
|
||||
|
||||
This test ensures:
|
||||
- The query is constructed correctly with all three filters
|
||||
- Results are ordered by created_at
|
||||
- The first matching binding is returned
|
||||
- No new binding is created (db.session.add is not called)
|
||||
- No commit is performed (db.session.commit is not called)
|
||||
"""
|
||||
# Arrange
|
||||
provider_name = "openai"
|
||||
model_name = "text-embedding-ada-002"
|
||||
collection_type = "dataset"
|
||||
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
|
||||
binding_id="binding-123",
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
collection_type=collection_type,
|
||||
)
|
||||
|
||||
# Mock the query chain: query().where().order_by().first()
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = existing_binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
provider_name=provider_name, model_name=model_name, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_binding
|
||||
assert result.id == "binding-123"
|
||||
assert result.provider_name == provider_name
|
||||
assert result.model_name == model_name
|
||||
assert result.type == collection_type
|
||||
|
||||
# Verify query was constructed correctly
|
||||
# The query should be constructed with DatasetCollectionBinding as the model
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
|
||||
# Verify the where clause was applied to filter by provider, model, and type
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
# Verify the results were ordered by created_at (ascending)
|
||||
# This ensures we get the oldest binding if multiple exist
|
||||
mock_where.order_by.assert_called_once()
|
||||
|
||||
# Verify no new binding was created
|
||||
# Since an existing binding was found, we should not create a new one
|
||||
mock_db_session.add.assert_not_called()
|
||||
|
||||
# Verify no commit was performed
|
||||
# Since no new binding was created, no database transaction is needed
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_get_dataset_collection_binding_create_new_binding_success(self, mock_db_session):
|
||||
"""
|
||||
Test successful creation of a new collection binding when none exists.
|
||||
|
||||
Verifies that when no binding exists in the database for the given
|
||||
provider, model, and collection type, the method creates a new binding
|
||||
with a generated collection name and commits it to the database.
|
||||
|
||||
This test ensures:
|
||||
- The query returns None (no existing binding)
|
||||
- A new DatasetCollectionBinding is created with correct attributes
|
||||
- Dataset.gen_collection_name_by_id is called to generate collection name
|
||||
- The new binding is added to the database session
|
||||
- The transaction is committed
|
||||
- The newly created binding is returned
|
||||
"""
|
||||
# Arrange
|
||||
provider_name = "cohere"
|
||||
model_name = "embed-english-v3.0"
|
||||
collection_type = "dataset"
|
||||
generated_collection_name = "collection-generated-xyz"
|
||||
|
||||
# Mock the query chain to return None (no existing binding)
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = None # No existing binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Mock Dataset.gen_collection_name_by_id to return a generated name
|
||||
with patch("services.dataset_service.Dataset.gen_collection_name_by_id") as mock_gen_name:
|
||||
mock_gen_name.return_value = generated_collection_name
|
||||
|
||||
# Mock uuid.uuid4 for the collection name generation
|
||||
mock_uuid = "test-uuid-123"
|
||||
with patch("services.dataset_service.uuid.uuid4", return_value=mock_uuid):
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
provider_name=provider_name, model_name=model_name, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.provider_name == provider_name
|
||||
assert result.model_name == model_name
|
||||
assert result.type == collection_type
|
||||
assert result.collection_name == generated_collection_name
|
||||
|
||||
# Verify Dataset.gen_collection_name_by_id was called with the generated UUID
|
||||
# This method generates a unique collection name based on the UUID
|
||||
# The UUID is converted to string before passing to the method
|
||||
mock_gen_name.assert_called_once_with(str(mock_uuid))
|
||||
|
||||
# Verify new binding was added to the database session
|
||||
# The add method should be called exactly once with the new binding instance
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
# Extract the binding that was added to verify its properties
|
||||
added_binding = mock_db_session.add.call_args[0][0]
|
||||
|
||||
# Verify the added binding is an instance of DatasetCollectionBinding
|
||||
# This ensures we're creating the correct type of object
|
||||
assert isinstance(added_binding, DatasetCollectionBinding)
|
||||
|
||||
# Verify all the binding properties are set correctly
|
||||
# These should match the input parameters to the method
|
||||
assert added_binding.provider_name == provider_name
|
||||
assert added_binding.model_name == model_name
|
||||
assert added_binding.type == collection_type
|
||||
|
||||
# Verify the collection name was set from the generated name
|
||||
# This ensures the binding has a valid collection identifier
|
||||
assert added_binding.collection_name == generated_collection_name
|
||||
|
||||
# Verify the transaction was committed
|
||||
# This ensures the new binding is persisted to the database
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_get_dataset_collection_binding_different_collection_type(self, mock_db_session):
|
||||
"""
|
||||
Test retrieval with a different collection type (not "dataset").
|
||||
|
||||
Verifies that the method correctly filters by collection_type, allowing
|
||||
different types of collections to coexist with the same provider/model
|
||||
combination.
|
||||
|
||||
This test ensures:
|
||||
- Collection type is properly used as a filter in the query
|
||||
- Different collection types can have separate bindings
|
||||
- The correct binding is returned based on type
|
||||
"""
|
||||
# Arrange
|
||||
provider_name = "openai"
|
||||
model_name = "text-embedding-ada-002"
|
||||
collection_type = "custom_type"
|
||||
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
|
||||
binding_id="binding-456",
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
collection_type=collection_type,
|
||||
)
|
||||
|
||||
# Mock the query chain
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = existing_binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
provider_name=provider_name, model_name=model_name, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_binding
|
||||
assert result.type == collection_type
|
||||
|
||||
# Verify query was constructed with the correct type filter
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
def test_get_dataset_collection_binding_default_collection_type(self, mock_db_session):
|
||||
"""
|
||||
Test retrieval with default collection type ("dataset").
|
||||
|
||||
Verifies that when collection_type is not provided, it defaults to "dataset"
|
||||
as specified in the method signature.
|
||||
|
||||
This test ensures:
|
||||
- The default value "dataset" is used when type is not specified
|
||||
- The query correctly filters by the default type
|
||||
"""
|
||||
# Arrange
|
||||
provider_name = "openai"
|
||||
model_name = "text-embedding-ada-002"
|
||||
# collection_type defaults to "dataset" in method signature
|
||||
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
|
||||
binding_id="binding-789",
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
collection_type="dataset", # Default type
|
||||
)
|
||||
|
||||
# Mock the query chain
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = existing_binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act - call without specifying collection_type (uses default)
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
provider_name=provider_name, model_name=model_name
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_binding
|
||||
assert result.type == "dataset"
|
||||
|
||||
# Verify query was constructed correctly
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
|
||||
def test_get_dataset_collection_binding_different_provider_model_combination(self, mock_db_session):
|
||||
"""
|
||||
Test retrieval with different provider/model combinations.
|
||||
|
||||
Verifies that bindings are correctly filtered by both provider_name and
|
||||
model_name, ensuring that different model combinations have separate bindings.
|
||||
|
||||
This test ensures:
|
||||
- Provider and model are both used as filters
|
||||
- Different combinations result in different bindings
|
||||
- The correct binding is returned for each combination
|
||||
"""
|
||||
# Arrange
|
||||
provider_name = "huggingface"
|
||||
model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
collection_type = "dataset"
|
||||
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
|
||||
binding_id="binding-hf-123",
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
collection_type=collection_type,
|
||||
)
|
||||
|
||||
# Mock the query chain
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = existing_binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
provider_name=provider_name, model_name=model_name, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_binding
|
||||
assert result.provider_name == provider_name
|
||||
assert result.model_name == model_name
|
||||
|
||||
# Verify query filters were applied correctly
|
||||
# The query should filter by both provider_name and model_name
|
||||
# This ensures different model combinations have separate bindings
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
|
||||
# Verify the where clause was applied with all three filters:
|
||||
# - provider_name filter
|
||||
# - model_name filter
|
||||
# - collection_type filter
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for get_dataset_collection_binding_by_id_and_type
|
||||
# ============================================================================
|
||||
# This section contains tests for the get_dataset_collection_binding_by_id_and_type
|
||||
# method, which retrieves a specific collection binding by its ID and type.
|
||||
#
|
||||
# Key differences from get_dataset_collection_binding:
|
||||
# 1. This method queries by ID and type, not by provider/model/type
|
||||
# 2. This method does NOT create a new binding if one doesn't exist
|
||||
# 3. This method raises ValueError if the binding is not found
|
||||
# 4. This method is typically used when you already know the binding ID
|
||||
#
|
||||
# Use cases:
|
||||
# - Retrieving a binding that was previously created
|
||||
# - Validating that a binding exists before using it
|
||||
# - Accessing binding metadata when you have the ID
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type method.
|
||||
|
||||
This test class covers collection binding retrieval by ID and type,
|
||||
including success scenarios and error handling for missing bindings.
|
||||
|
||||
The get_dataset_collection_binding_by_id_and_type method:
|
||||
1. Queries for a binding by collection_binding_id and collection_type
|
||||
2. Orders results by created_at (ascending) and takes the first match
|
||||
3. If no binding exists, raises ValueError("Dataset collection binding not found")
|
||||
4. Returns the found binding
|
||||
|
||||
Unlike get_dataset_collection_binding, this method does NOT create a new
|
||||
binding if one doesn't exist - it only retrieves existing bindings.
|
||||
|
||||
Test scenarios include:
|
||||
- Successful retrieval of existing bindings
|
||||
- Error handling for missing bindings
|
||||
- Different collection types
|
||||
- Default collection type behavior
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""
|
||||
Mock database session for testing database operations.
|
||||
|
||||
Provides a mocked database session that can be used to verify:
|
||||
- Query construction with ID and type filters
|
||||
- Ordering by created_at
|
||||
- First result retrieval
|
||||
|
||||
The mock is configured to return a query builder that supports
|
||||
chaining operations like .where(), .order_by(), and .first().
|
||||
"""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_success(self, mock_db_session):
|
||||
"""
|
||||
Test successful retrieval of a collection binding by ID and type.
|
||||
|
||||
Verifies that when a binding exists in the database with the given
|
||||
ID and collection type, the method returns the binding.
|
||||
|
||||
This test ensures:
|
||||
- The query is constructed correctly with ID and type filters
|
||||
- Results are ordered by created_at
|
||||
- The first matching binding is returned
|
||||
- No error is raised
|
||||
"""
|
||||
# Arrange
|
||||
collection_binding_id = "binding-123"
|
||||
collection_type = "dataset"
|
||||
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
|
||||
binding_id=collection_binding_id,
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_type=collection_type,
|
||||
)
|
||||
|
||||
# Mock the query chain: query().where().order_by().first()
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = existing_binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id=collection_binding_id, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_binding
|
||||
assert result.id == collection_binding_id
|
||||
assert result.type == collection_type
|
||||
|
||||
# Verify query was constructed correctly
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
mock_query.where.assert_called_once()
|
||||
mock_where.order_by.assert_called_once()
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, mock_db_session):
|
||||
"""
|
||||
Test error handling when binding is not found.
|
||||
|
||||
Verifies that when no binding exists in the database with the given
|
||||
ID and collection type, the method raises a ValueError with the
|
||||
message "Dataset collection binding not found".
|
||||
|
||||
This test ensures:
|
||||
- The query returns None (no existing binding)
|
||||
- ValueError is raised with the correct message
|
||||
- No binding is returned
|
||||
"""
|
||||
# Arrange
|
||||
collection_binding_id = "non-existent-binding"
|
||||
collection_type = "dataset"
|
||||
|
||||
# Mock the query chain to return None (no existing binding)
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = None # No existing binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Dataset collection binding not found"):
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id=collection_binding_id, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Verify query was attempted
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, mock_db_session):
|
||||
"""
|
||||
Test retrieval with a different collection type.
|
||||
|
||||
Verifies that the method correctly filters by collection_type, ensuring
|
||||
that bindings with the same ID but different types are treated as
|
||||
separate entities.
|
||||
|
||||
This test ensures:
|
||||
- Collection type is properly used as a filter in the query
|
||||
- Different collection types can have separate bindings with same ID
|
||||
- The correct binding is returned based on type
|
||||
"""
|
||||
# Arrange
|
||||
collection_binding_id = "binding-456"
|
||||
collection_type = "custom_type"
|
||||
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
|
||||
binding_id=collection_binding_id,
|
||||
provider_name="cohere",
|
||||
model_name="embed-english-v3.0",
|
||||
collection_type=collection_type,
|
||||
)
|
||||
|
||||
# Mock the query chain
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = existing_binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id=collection_binding_id, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_binding
|
||||
assert result.id == collection_binding_id
|
||||
assert result.type == collection_type
|
||||
|
||||
# Verify query was constructed with the correct type filter
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, mock_db_session):
|
||||
"""
|
||||
Test retrieval with default collection type ("dataset").
|
||||
|
||||
Verifies that when collection_type is not provided, it defaults to "dataset"
|
||||
as specified in the method signature.
|
||||
|
||||
This test ensures:
|
||||
- The default value "dataset" is used when type is not specified
|
||||
- The query correctly filters by the default type
|
||||
- The correct binding is returned
|
||||
"""
|
||||
# Arrange
|
||||
collection_binding_id = "binding-789"
|
||||
# collection_type defaults to "dataset" in method signature
|
||||
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
|
||||
binding_id=collection_binding_id,
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_type="dataset", # Default type
|
||||
)
|
||||
|
||||
# Mock the query chain
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = existing_binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act - call without specifying collection_type (uses default)
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id=collection_binding_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_binding
|
||||
assert result.id == collection_binding_id
|
||||
assert result.type == "dataset"
|
||||
|
||||
# Verify query was constructed correctly
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, mock_db_session):
|
||||
"""
|
||||
Test error handling when binding exists but with wrong collection type.
|
||||
|
||||
Verifies that when a binding exists with the given ID but a different
|
||||
collection type, the method raises a ValueError because the binding
|
||||
doesn't match both the ID and type criteria.
|
||||
|
||||
This test ensures:
|
||||
- The query correctly filters by both ID and type
|
||||
- Bindings with matching ID but different type are not returned
|
||||
- ValueError is raised when no matching binding is found
|
||||
"""
|
||||
# Arrange
|
||||
collection_binding_id = "binding-123"
|
||||
collection_type = "dataset"
|
||||
|
||||
# Mock the query chain to return None (binding exists but with different type)
|
||||
mock_query = Mock()
|
||||
mock_where = Mock()
|
||||
mock_order_by = Mock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.order_by.return_value = mock_order_by
|
||||
mock_order_by.first.return_value = None # No matching binding
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Dataset collection binding not found"):
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id=collection_binding_id, collection_type=collection_type
|
||||
)
|
||||
|
||||
# Verify query was attempted with both ID and type filters
|
||||
# The query should filter by both collection_binding_id and collection_type
|
||||
# This ensures we only get bindings that match both criteria
|
||||
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
|
||||
|
||||
# Verify the where clause was applied with both filters:
|
||||
# - collection_binding_id filter (exact match)
|
||||
# - collection_type filter (exact match)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
# Note: The order_by and first() calls are also part of the query chain,
|
||||
# but we don't need to verify them separately since they're part of the
|
||||
# standard query pattern used by both methods in this service.
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Additional Test Scenarios and Edge Cases
|
||||
# ============================================================================
|
||||
# The following section could contain additional test scenarios if needed:
|
||||
#
|
||||
# Potential additional tests:
|
||||
# 1. Test with multiple existing bindings (verify ordering by created_at)
|
||||
# 2. Test with very long provider/model names (boundary testing)
|
||||
# 3. Test with special characters in provider/model names
|
||||
# 4. Test concurrent binding creation (thread safety)
|
||||
# 5. Test database rollback scenarios
|
||||
# 6. Test with None values for optional parameters
|
||||
# 7. Test with empty strings for required parameters
|
||||
# 8. Test collection name generation uniqueness
|
||||
# 9. Test with different UUID formats
|
||||
# 10. Test query performance with large datasets
|
||||
#
|
||||
# These scenarios are not currently implemented but could be added if needed
|
||||
# based on real-world usage patterns or discovered edge cases.
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Notes and Best Practices
|
||||
# ============================================================================
|
||||
#
|
||||
# When using DatasetCollectionBindingService in production code, consider:
|
||||
#
|
||||
# 1. Error Handling:
|
||||
# - Always handle ValueError exceptions when calling
|
||||
# get_dataset_collection_binding_by_id_and_type
|
||||
# - Check return values from get_dataset_collection_binding to ensure
|
||||
# bindings were created successfully
|
||||
#
|
||||
# 2. Performance Considerations:
|
||||
# - The service queries the database on every call, so consider caching
|
||||
# bindings if they're accessed frequently
|
||||
# - Collection bindings are typically long-lived, so caching is safe
|
||||
#
|
||||
# 3. Transaction Management:
|
||||
# - New bindings are automatically committed to the database
|
||||
# - If you need to rollback, ensure you're within a transaction context
|
||||
#
|
||||
# 4. Collection Type Usage:
|
||||
# - Use "dataset" for standard dataset collections
|
||||
# - Use custom types only when you need to separate collections by purpose
|
||||
# - Be consistent with collection type naming across your application
|
||||
#
|
||||
# 5. Provider and Model Naming:
|
||||
# - Use consistent provider names (e.g., "openai", not "OpenAI" or "OPENAI")
|
||||
# - Use exact model names as provided by the model provider
|
||||
# - These names are case-sensitive and must match exactly
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Database Schema Reference
|
||||
# ============================================================================
|
||||
#
|
||||
# The DatasetCollectionBinding model has the following structure:
|
||||
#
|
||||
# - id: StringUUID (primary key, auto-generated)
|
||||
# - provider_name: String(255) (required, e.g., "openai", "cohere")
|
||||
# - model_name: String(255) (required, e.g., "text-embedding-ada-002")
|
||||
# - type: String(40) (required, default: "dataset")
|
||||
# - collection_name: String(64) (required, unique collection identifier)
|
||||
# - created_at: DateTime (auto-generated timestamp)
|
||||
#
|
||||
# Indexes:
|
||||
# - Primary key on id
|
||||
# - Composite index on (provider_name, model_name) for efficient lookups
|
||||
#
|
||||
# Relationships:
|
||||
# - One binding can be referenced by multiple datasets
|
||||
# - Datasets reference bindings via collection_binding_id
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mocking Strategy Documentation
|
||||
# ============================================================================
|
||||
#
|
||||
# This test suite uses extensive mocking to isolate the unit under test.
|
||||
# Here's how the mocking strategy works:
|
||||
#
|
||||
# 1. Database Session Mocking:
|
||||
# - db.session is patched to prevent actual database access
|
||||
# - Query chains are mocked to return predictable results
|
||||
# - Add and commit operations are tracked for verification
|
||||
#
|
||||
# 2. Query Chain Mocking:
|
||||
# - query() returns a mock query object
|
||||
# - where() returns a mock where object
|
||||
# - order_by() returns a mock order_by object
|
||||
# - first() returns the final result (binding or None)
|
||||
#
|
||||
# 3. UUID Generation Mocking:
|
||||
# - uuid.uuid4() is mocked to return predictable UUIDs
|
||||
# - This ensures collection names are generated consistently in tests
|
||||
#
|
||||
# 4. Collection Name Generation Mocking:
|
||||
# - Dataset.gen_collection_name_by_id() is mocked
|
||||
# - This allows us to verify the method is called correctly
|
||||
# - We can control the generated collection name for testing
|
||||
#
|
||||
# Benefits of this approach:
|
||||
# - Tests run quickly (no database I/O)
|
||||
# - Tests are deterministic (no random UUIDs)
|
||||
# - Tests are isolated (no side effects)
|
||||
# - Tests are maintainable (clear mock setup)
|
||||
#
|
||||
# ============================================================================
|
||||
@ -96,7 +96,6 @@ from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
@ -536,421 +535,6 @@ class TestDatasetServiceUpdateDataset:
|
||||
DatasetService.update_dataset(dataset_id, update_data, user)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for delete_dataset
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDatasetServiceDeleteDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.delete_dataset method.
|
||||
|
||||
This test class covers the dataset deletion functionality, including
|
||||
permission validation, event signaling, and database cleanup.
|
||||
|
||||
The delete_dataset method:
|
||||
1. Retrieves the dataset by ID
|
||||
2. Returns False if dataset not found
|
||||
3. Validates user permissions
|
||||
4. Sends dataset_was_deleted event
|
||||
5. Deletes dataset from database
|
||||
6. Commits transaction
|
||||
7. Returns True on success
|
||||
|
||||
Test scenarios include:
|
||||
- Successful dataset deletion
|
||||
- Permission validation
|
||||
- Event signaling
|
||||
- Database cleanup
|
||||
- Not found handling
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""
|
||||
Mock dataset service dependencies for testing.
|
||||
|
||||
Provides mocked dependencies including:
|
||||
- get_dataset method
|
||||
- check_dataset_permission method
|
||||
- dataset_was_deleted event signal
|
||||
- Database session
|
||||
"""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
|
||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||
patch("services.dataset_service.dataset_was_deleted") as mock_event,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
):
|
||||
yield {
|
||||
"get_dataset": mock_get_dataset,
|
||||
"check_permission": mock_check_perm,
|
||||
"dataset_was_deleted": mock_event,
|
||||
"db_session": mock_db,
|
||||
}
|
||||
|
||||
def test_delete_dataset_success(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of a dataset.
|
||||
|
||||
Verifies that when all validation passes, a dataset is deleted
|
||||
correctly with proper event signaling and database cleanup.
|
||||
|
||||
This test ensures:
|
||||
- Dataset is retrieved correctly
|
||||
- Permission is checked
|
||||
- Event is sent for cleanup
|
||||
- Dataset is deleted from database
|
||||
- Transaction is committed
|
||||
- Method returns True
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
|
||||
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
# Act
|
||||
result = DatasetService.delete_dataset(dataset_id, user)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify dataset was retrieved
|
||||
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)
|
||||
|
||||
# Verify permission was checked
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify event was sent for cleanup
|
||||
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
|
||||
|
||||
# Verify dataset was deleted and committed
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
|
||||
|
||||
def test_delete_dataset_not_found(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test handling when dataset is not found.
|
||||
|
||||
Verifies that when the dataset ID doesn't exist, the method
|
||||
returns False without performing any operations.
|
||||
|
||||
This test ensures:
|
||||
- Method returns False when dataset not found
|
||||
- No permission checks are performed
|
||||
- No events are sent
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "non-existent-dataset"
|
||||
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = None
|
||||
|
||||
# Act
|
||||
result = DatasetService.delete_dataset(dataset_id, user)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
# Verify no operations were performed
|
||||
mock_dataset_service_dependencies["check_permission"].assert_not_called()
|
||||
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called()
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_not_called()
|
||||
|
||||
def test_delete_dataset_permission_denied_error(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test error handling when user lacks permission.
|
||||
|
||||
Verifies that when the user doesn't have permission to delete
|
||||
the dataset, a NoPermissionError is raised.
|
||||
|
||||
This test ensures:
|
||||
- Permission validation works correctly
|
||||
- Error is raised before deletion
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
|
||||
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.delete_dataset(dataset_id, user)
|
||||
|
||||
# Verify no deletion was attempted
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_not_called()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for dataset_use_check
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDatasetServiceDatasetUseCheck:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.dataset_use_check method.
|
||||
|
||||
This test class covers the dataset use checking functionality, which
|
||||
determines if a dataset is currently being used by any applications.
|
||||
|
||||
The dataset_use_check method:
|
||||
1. Queries AppDatasetJoin table for the dataset ID
|
||||
2. Returns True if dataset is in use
|
||||
3. Returns False if dataset is not in use
|
||||
|
||||
Test scenarios include:
|
||||
- Dataset in use (has AppDatasetJoin records)
|
||||
- Dataset not in use (no AppDatasetJoin records)
|
||||
- Database query validation
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""
|
||||
Mock database session for testing.
|
||||
|
||||
Provides a mocked database session that can be used to verify
|
||||
query construction and execution.
|
||||
"""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_dataset_use_check_in_use(self, mock_db_session):
|
||||
"""
|
||||
Test detection when dataset is in use.
|
||||
|
||||
Verifies that when a dataset has associated AppDatasetJoin records,
|
||||
the method returns True.
|
||||
|
||||
This test ensures:
|
||||
- Query is constructed correctly
|
||||
- True is returned when dataset is in use
|
||||
- Database query is executed
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
|
||||
# Mock the exists() query to return True
|
||||
mock_execute = Mock()
|
||||
mock_execute.scalar_one.return_value = True
|
||||
mock_db_session.execute.return_value = mock_execute
|
||||
|
||||
# Act
|
||||
result = DatasetService.dataset_use_check(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify query was executed
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
def test_dataset_use_check_not_in_use(self, mock_db_session):
|
||||
"""
|
||||
Test detection when dataset is not in use.
|
||||
|
||||
Verifies that when a dataset has no associated AppDatasetJoin records,
|
||||
the method returns False.
|
||||
|
||||
This test ensures:
|
||||
- Query is constructed correctly
|
||||
- False is returned when dataset is not in use
|
||||
- Database query is executed
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
|
||||
# Mock the exists() query to return False
|
||||
mock_execute = Mock()
|
||||
mock_execute.scalar_one.return_value = False
|
||||
mock_db_session.execute.return_value = mock_execute
|
||||
|
||||
# Act
|
||||
result = DatasetService.dataset_use_check(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
# Verify query was executed
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for update_dataset_api_status
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.update_dataset_api_status method.
|
||||
|
||||
This test class covers the dataset API status update functionality,
|
||||
which enables or disables API access for a dataset.
|
||||
|
||||
The update_dataset_api_status method:
|
||||
1. Retrieves the dataset by ID
|
||||
2. Validates dataset exists
|
||||
3. Updates enable_api field
|
||||
4. Updates updated_by and updated_at fields
|
||||
5. Commits transaction
|
||||
|
||||
Test scenarios include:
|
||||
- Successful API status enable
|
||||
- Successful API status disable
|
||||
- Dataset not found error
|
||||
- Current user validation
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""
|
||||
Mock dataset service dependencies for testing.
|
||||
|
||||
Provides mocked dependencies including:
|
||||
- get_dataset method
|
||||
- current_user context
|
||||
- Database session
|
||||
- Current time utilities
|
||||
"""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
|
||||
patch(
|
||||
"services.dataset_service.current_user", create_autospec(Account, instance=True)
|
||||
) as mock_current_user,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
mock_current_user.id = "user-123"
|
||||
|
||||
yield {
|
||||
"get_dataset": mock_get_dataset,
|
||||
"current_user": mock_current_user,
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
}
|
||||
|
||||
def test_update_dataset_api_status_enable_success(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test successful enabling of dataset API access.
|
||||
|
||||
Verifies that when all validation passes, the dataset's API
|
||||
access is enabled and the update is committed.
|
||||
|
||||
This test ensures:
|
||||
- Dataset is retrieved correctly
|
||||
- enable_api is set to True
|
||||
- updated_by and updated_at are set
|
||||
- Transaction is committed
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id, enable_api=False)
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
# Act
|
||||
DatasetService.update_dataset_api_status(dataset_id, True)
|
||||
|
||||
# Assert
|
||||
assert dataset.enable_api is True
|
||||
assert dataset.updated_by == "user-123"
|
||||
assert dataset.updated_at == mock_dataset_service_dependencies["current_time"]
|
||||
|
||||
# Verify dataset was retrieved
|
||||
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)
|
||||
|
||||
# Verify transaction was committed
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
|
||||
|
||||
def test_update_dataset_api_status_disable_success(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test successful disabling of dataset API access.
|
||||
|
||||
Verifies that when all validation passes, the dataset's API
|
||||
access is disabled and the update is committed.
|
||||
|
||||
This test ensures:
|
||||
- Dataset is retrieved correctly
|
||||
- enable_api is set to False
|
||||
- updated_by and updated_at are set
|
||||
- Transaction is committed
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id, enable_api=True)
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
# Act
|
||||
DatasetService.update_dataset_api_status(dataset_id, False)
|
||||
|
||||
# Assert
|
||||
assert dataset.enable_api is False
|
||||
assert dataset.updated_by == "user-123"
|
||||
|
||||
# Verify transaction was committed
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
|
||||
|
||||
def test_update_dataset_api_status_not_found_error(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test error handling when dataset is not found.
|
||||
|
||||
Verifies that when the dataset ID doesn't exist, a NotFound
|
||||
exception is raised.
|
||||
|
||||
This test ensures:
|
||||
- NotFound exception is raised
|
||||
- No updates are performed
|
||||
- Error message is appropriate
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "non-existent-dataset"
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetService.update_dataset_api_status(dataset_id, True)
|
||||
|
||||
# Verify no commit was attempted
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_not_called()
|
||||
|
||||
def test_update_dataset_api_status_missing_current_user_error(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test error handling when current_user is missing.
|
||||
|
||||
Verifies that when current_user is None or has no ID, a ValueError
|
||||
is raised.
|
||||
|
||||
This test ensures:
|
||||
- ValueError is raised when current_user is None
|
||||
- Error message is clear
|
||||
- No updates are committed
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
mock_dataset_service_dependencies["current_user"].id = None # Missing user ID
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Current user or current user id not found"):
|
||||
DatasetService.update_dataset_api_status(dataset_id, True)
|
||||
|
||||
# Verify no commit was attempted
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_not_called()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for update_rag_pipeline_dataset_settings
|
||||
# ============================================================================
|
||||
@ -1058,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
|
||||
|
||||
# Mock embedding model
|
||||
mock_embedding_model = Mock()
|
||||
mock_embedding_model.model = "text-embedding-ada-002"
|
||||
mock_embedding_model.model_name = "text-embedding-ada-002"
|
||||
mock_embedding_model.provider = "openai"
|
||||
mock_embedding_model.credentials = {}
|
||||
|
||||
mock_model_schema = Mock()
|
||||
mock_model_schema.features = []
|
||||
|
||||
mock_text_embedding_model = Mock()
|
||||
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
|
||||
mock_embedding_model.model_type_instance = mock_text_embedding_model
|
||||
|
||||
mock_model_instance = Mock()
|
||||
mock_model_instance.get_model_instance.return_value = mock_embedding_model
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,141 @@
|
||||
"""Unit tests for enterprise service integrations.
|
||||
|
||||
This module covers the enterprise-only default workspace auto-join behavior:
|
||||
- Enterprise mode disabled: no external calls
|
||||
- Successful join / skipped join: no errors
|
||||
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.enterprise.enterprise_service import (
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
try_join_default_workspace,
|
||||
)
|
||||
|
||||
|
||||
class TestJoinDefaultWorkspace:
|
||||
def test_join_default_workspace_success(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
|
||||
|
||||
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||
mock_send_request.return_value = response
|
||||
|
||||
result = EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
|
||||
assert isinstance(result, DefaultWorkspaceJoinResult)
|
||||
assert result.workspace_id == response["workspace_id"]
|
||||
assert result.joined is True
|
||||
assert result.message == "ok"
|
||||
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=1.0,
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
def test_join_default_workspace_invalid_response_format_raises(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||
mock_send_request.return_value = "not-a-dict"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid response format"):
|
||||
EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
|
||||
def test_join_default_workspace_invalid_account_id_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
|
||||
|
||||
def test_join_default_workspace_missing_required_fields_raises(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
response = {"workspace_id": "", "message": "ok"} # missing "joined"
|
||||
|
||||
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||
mock_send_request.return_value = response
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid response payload"):
|
||||
EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
|
||||
def test_join_default_workspace_joined_without_workspace_id_raises(self):
|
||||
with pytest.raises(ValueError, match="workspace_id must be non-empty when joined is True"):
|
||||
DefaultWorkspaceJoinResult(workspace_id="", joined=True, message="ok")
|
||||
|
||||
|
||||
class TestTryJoinDefaultWorkspace:
|
||||
def test_try_join_default_workspace_enterprise_disabled_noop(self):
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
|
||||
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
mock_join.assert_not_called()
|
||||
|
||||
def test_try_join_default_workspace_successful_join_does_not_raise(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_join.return_value = DefaultWorkspaceJoinResult(
|
||||
workspace_id="22222222-2222-2222-2222-222222222222",
|
||||
joined=True,
|
||||
message="ok",
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
try_join_default_workspace(account_id)
|
||||
|
||||
mock_join.assert_called_once_with(account_id=account_id)
|
||||
|
||||
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_join.return_value = DefaultWorkspaceJoinResult(
|
||||
workspace_id="",
|
||||
joined=False,
|
||||
message="no default workspace configured",
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
try_join_default_workspace(account_id)
|
||||
|
||||
mock_join.assert_called_once_with(account_id=account_id)
|
||||
|
||||
def test_try_join_default_workspace_api_failure_soft_fails(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_join.side_effect = Exception("network failure")
|
||||
|
||||
# Should not raise
|
||||
try_join_default_workspace(account_id)
|
||||
|
||||
mock_join.assert_called_once_with(account_id=account_id)
|
||||
|
||||
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
|
||||
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Should not raise even though UUID parsing fails inside join_default_workspace
|
||||
try_join_default_workspace("not-a-uuid")
|
||||
@ -27,7 +27,7 @@ class TestTraceparentPropagation:
|
||||
@pytest.fixture
|
||||
def mock_httpx_client(self):
|
||||
"""Mock httpx.Client for testing."""
|
||||
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
|
||||
with patch("services.enterprise.base.httpx.Client", autospec=True) as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client
|
||||
mock_client_class.return_value.__exit__.return_value = None
|
||||
@ -44,7 +44,9 @@ class TestTraceparentPropagation:
|
||||
# Arrange
|
||||
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||
|
||||
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
|
||||
with patch(
|
||||
"services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent, autospec=True
|
||||
):
|
||||
# Act
|
||||
EnterpriseRequest.send_request("GET", "/test")
|
||||
|
||||
|
||||
@ -135,8 +135,8 @@ class TestExternalDatasetServiceGetExternalKnowledgeApis:
|
||||
"""
|
||||
|
||||
with (
|
||||
patch("services.external_knowledge_service.db.paginate") as mock_paginate,
|
||||
patch("services.external_knowledge_service.select"),
|
||||
patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate,
|
||||
patch("services.external_knowledge_service.select", autospec=True),
|
||||
):
|
||||
yield mock_paginate
|
||||
|
||||
@ -245,7 +245,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
Patch ``db.session`` for all CRUD tests in this class.
|
||||
"""
|
||||
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
|
||||
@ -263,7 +263,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
}
|
||||
|
||||
# We do not want to actually call the remote endpoint here, so we patch the validator.
|
||||
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
|
||||
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
|
||||
result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
|
||||
|
||||
assert isinstance(result, ExternalKnowledgeApis)
|
||||
@ -386,7 +386,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
|
||||
@ -447,7 +447,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
|
||||
@ -520,7 +520,7 @@ class TestExternalDatasetServiceProcessExternalApi:
|
||||
|
||||
fake_response = httpx.Response(200)
|
||||
|
||||
with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
|
||||
with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
|
||||
mock_post.return_value = fake_response
|
||||
|
||||
result = ExternalDatasetService.process_external_api(settings, files=None)
|
||||
@ -681,7 +681,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_create_external_dataset_success(self, mock_db_session: MagicMock):
|
||||
@ -801,7 +801,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
|
||||
@ -838,7 +838,9 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
|
||||
metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
|
||||
|
||||
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
|
||||
with patch.object(
|
||||
ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True
|
||||
) as mock_process:
|
||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
@ -908,7 +910,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
fake_response.status_code = 500
|
||||
fake_response.json.return_value = {}
|
||||
|
||||
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
|
||||
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True):
|
||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="ds-1",
|
||||
|
||||
@ -146,7 +146,7 @@ class TestHitTestingServiceRetrieve:
|
||||
Provides a mocked database session for testing database operations
|
||||
like adding and committing DatasetQuery records.
|
||||
"""
|
||||
with patch("services.hit_testing_service.db.session") as mock_db:
|
||||
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
|
||||
@ -174,9 +174,11 @@ class TestHitTestingServiceRetrieve:
|
||||
]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1] # start, end
|
||||
mock_retrieve.return_value = documents
|
||||
@ -218,9 +220,11 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_retrieve.return_value = documents
|
||||
@ -268,10 +272,12 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
|
||||
@ -311,8 +317,10 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True)
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
):
|
||||
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
|
||||
mock_format.return_value = []
|
||||
@ -346,9 +354,11 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_retrieve.return_value = documents
|
||||
@ -380,7 +390,7 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
Provides a mocked database session for testing database operations
|
||||
like adding and committing DatasetQuery records.
|
||||
"""
|
||||
with patch("services.hit_testing_service.db.session") as mock_db:
|
||||
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_external_retrieve_success(self, mock_db_session):
|
||||
@ -403,8 +413,10 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
|
||||
) as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_external_retrieve.return_value = external_documents
|
||||
@ -467,8 +479,10 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
|
||||
) as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_external_retrieve.return_value = external_documents
|
||||
@ -499,8 +513,10 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
metadata_filtering_conditions = {}
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
|
||||
) as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_external_retrieve.return_value = []
|
||||
@ -542,7 +558,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
|
||||
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85),
|
||||
]
|
||||
|
||||
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
|
||||
with patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format:
|
||||
mock_format.return_value = mock_records
|
||||
|
||||
# Act
|
||||
@ -566,7 +584,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
|
||||
query = "test query"
|
||||
documents = []
|
||||
|
||||
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
|
||||
with patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format:
|
||||
mock_format.return_value = []
|
||||
|
||||
# Act
|
||||
|
||||
@ -147,7 +147,7 @@ class TestSegmentServiceCreateSegment:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -172,10 +172,12 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -219,10 +221,12 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -257,11 +261,13 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager_class,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -292,10 +298,12 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -317,7 +325,7 @@ class TestSegmentServiceUpdateSegment:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -338,10 +346,10 @@ class TestSegmentServiceUpdateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_redis_get.return_value = None # Not indexing
|
||||
mock_hash.return_value = "new-hash"
|
||||
@ -368,10 +376,10 @@ class TestSegmentServiceUpdateSegment:
|
||||
args = SegmentUpdateArgs(enabled=False)
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
|
||||
patch("services.dataset_service.disable_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
|
||||
patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
@ -394,7 +402,7 @@ class TestSegmentServiceUpdateSegment:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
args = SegmentUpdateArgs(content="Updated content")
|
||||
|
||||
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
|
||||
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
|
||||
mock_redis_get.return_value = "1" # Indexing in progress
|
||||
|
||||
# Act & Assert
|
||||
@ -409,7 +417,7 @@ class TestSegmentServiceUpdateSegment:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
args = SegmentUpdateArgs(content="Updated content")
|
||||
|
||||
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
|
||||
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
|
||||
mock_redis_get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@ -427,10 +435,10 @@ class TestSegmentServiceUpdateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_hash.return_value = "new-hash"
|
||||
@ -456,7 +464,7 @@ class TestSegmentServiceDeleteSegment:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_delete_segment_success(self, mock_db_session):
|
||||
@ -471,10 +479,10 @@ class TestSegmentServiceDeleteSegment:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.select") as mock_select,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
|
||||
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_select.return_value.where.return_value = mock_select
|
||||
@ -495,8 +503,8 @@ class TestSegmentServiceDeleteSegment:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
|
||||
@ -515,7 +523,7 @@ class TestSegmentServiceDeleteSegment:
|
||||
document = SegmentTestDataFactory.create_document_mock()
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
|
||||
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
|
||||
mock_redis_get.return_value = "1" # Deletion in progress
|
||||
|
||||
# Act & Assert
|
||||
@ -529,7 +537,7 @@ class TestSegmentServiceDeleteSegments:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -562,8 +570,8 @@ class TestSegmentServiceDeleteSegments:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.select") as mock_select_func,
|
||||
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select_func,
|
||||
):
|
||||
mock_select_func.return_value = mock_select
|
||||
|
||||
@ -594,7 +602,7 @@ class TestSegmentServiceUpdateSegmentsStatus:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -623,9 +631,9 @@ class TestSegmentServiceUpdateSegmentsStatus:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.enable_segments_to_index_task") as mock_task,
|
||||
patch("services.dataset_service.select") as mock_select_func,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select_func,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_select_func.return_value = mock_select
|
||||
@ -657,10 +665,10 @@ class TestSegmentServiceUpdateSegmentsStatus:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.disable_segments_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.select") as mock_select_func,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select_func,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
@ -693,7 +701,7 @@ class TestSegmentServiceGetSegments:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -771,7 +779,7 @@ class TestSegmentServiceGetSegmentById:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_get_segment_by_id_success(self, mock_db_session):
|
||||
@ -814,7 +822,7 @@ class TestSegmentServiceGetChildChunks:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -876,7 +884,7 @@ class TestSegmentServiceGetChildChunkById:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_get_child_chunk_by_id_success(self, mock_db_session):
|
||||
@ -919,7 +927,7 @@ class TestSegmentServiceCreateChildChunk:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -942,9 +950,11 @@ class TestSegmentServiceCreateChildChunk:
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -972,9 +982,11 @@ class TestSegmentServiceCreateChildChunk:
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -994,7 +1006,7 @@ class TestSegmentServiceUpdateChildChunk:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -1014,8 +1026,10 @@ class TestSegmentServiceUpdateChildChunk:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
|
||||
@ -1040,8 +1054,10 @@ class TestSegmentServiceUpdateChildChunk:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_vector_service.side_effect = Exception("Vector indexing failed")
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
@ -1059,7 +1075,7 @@ class TestSegmentServiceDeleteChildChunk:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_delete_child_chunk_success(self, mock_db_session):
|
||||
@ -1068,7 +1084,9 @@ class TestSegmentServiceDeleteChildChunk:
|
||||
chunk = SegmentTestDataFactory.create_child_chunk_mock()
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
|
||||
with patch(
|
||||
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service:
|
||||
# Act
|
||||
SegmentService.delete_child_chunk(chunk, dataset)
|
||||
|
||||
@ -1083,7 +1101,9 @@ class TestSegmentServiceDeleteChildChunk:
|
||||
chunk = SegmentTestDataFactory.create_child_chunk_mock()
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
|
||||
with patch(
|
||||
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service:
|
||||
mock_vector_service.side_effect = Exception("Vector deletion failed")
|
||||
|
||||
# Act & Assert
|
||||
|
||||
@ -1064,6 +1064,67 @@ class TestRegisterService:
|
||||
|
||||
# ==================== Registration Tests ====================
|
||||
|
||||
def test_create_account_and_tenant_calls_default_workspace_join_when_enterprise_enabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should be invoked when ENTERPRISE_ENABLED is True."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
result = AccountService.create_account_and_tenant(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password=None,
|
||||
)
|
||||
|
||||
assert result == mock_account
|
||||
mock_create_workspace.assert_called_once_with(account=mock_account)
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
|
||||
def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
AccountService.create_account_and_tenant(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password=None,
|
||||
)
|
||||
|
||||
mock_create_workspace.assert_called_once_with(account=mock_account)
|
||||
mock_join_default_workspace.assert_not_called()
|
||||
|
||||
def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies):
|
||||
"""Test successful account registration."""
|
||||
# Setup mocks
|
||||
@ -1115,6 +1176,65 @@ class TestRegisterService:
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_register_calls_default_workspace_join_when_enterprise_enabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should be invoked after successful register commit."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
result = RegisterService.register(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
create_workspace_required=False,
|
||||
)
|
||||
|
||||
assert result == mock_account
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
|
||||
def test_register_does_not_call_default_workspace_join_when_enterprise_disabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
RegisterService.register(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
create_workspace_required=False,
|
||||
)
|
||||
|
||||
mock_join_default_workspace.assert_not_called()
|
||||
|
||||
def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies):
|
||||
"""Test account registration with OAuth integration."""
|
||||
# Setup mocks
|
||||
|
||||
@ -63,3 +63,56 @@ def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert pause_state_config is not None
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
|
||||
def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch):
|
||||
"""
|
||||
Regression test: ADVANCED_CHAT in blocking mode should return a plain dict
|
||||
(non-streaming), and must not go through the async retrieve_events path.
|
||||
Keeps behavior consistent with WORKFLOW blocking branch.
|
||||
"""
|
||||
# Disable billing and stub RateLimit to a no-op that just passes values through
|
||||
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
|
||||
# Arrange a fake workflow and wire AppGenerateService._get_workflow to return it
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
|
||||
# Spy on the streaming retrieval path to ensure it's NOT called
|
||||
retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events")
|
||||
|
||||
# Make AdvancedChatAppGenerator.generate return a plain dict when streaming=False
|
||||
generate_spy = mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
# Minimal app model for ADVANCED_CHAT
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.ADVANCED_CHAT
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
|
||||
# Must include query and inputs for AdvancedChatAppGenerator
|
||||
args = {"workflow_id": "wf-1", "query": "hello", "inputs": {}}
|
||||
|
||||
# Act: call service with streaming=False (blocking mode)
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=MagicMock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# Assert: returns the dict from generate(), and did not call retrieve_events()
|
||||
assert result == {"result": "ok"}
|
||||
assert generate_spy.call_args.kwargs.get("streaming") is False
|
||||
retrieve_spy.assert_not_called()
|
||||
|
||||
@ -44,9 +44,10 @@ class TestAppTaskService:
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
if should_call_graph_engine:
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
mock_graph_engine_manager.assert_called_once()
|
||||
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
|
||||
else:
|
||||
mock_graph_engine_manager.send_stop_command.assert_not_called()
|
||||
mock_graph_engine_manager.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invoke_from",
|
||||
@ -76,7 +77,8 @@ class TestAppTaskService:
|
||||
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
mock_graph_engine_manager.assert_called_once()
|
||||
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
|
||||
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
@ -96,7 +98,7 @@ class TestAppTaskService:
|
||||
app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
# Simulate GraphEngine failure
|
||||
mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error")
|
||||
mock_graph_engine_manager.return_value.send_stop_command.side_effect = Exception("GraphEngine error")
|
||||
|
||||
# Act & Assert - should raise the exception since it's not caught
|
||||
with pytest.raises(Exception, match="GraphEngine error"):
|
||||
|
||||
@ -15,8 +15,8 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
class TestWorkflowRunArchiver:
|
||||
"""Tests for the WorkflowRunArchiver class."""
|
||||
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config")
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config", autospec=True)
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage", autospec=True)
|
||||
def test_archiver_initialization(self, mock_get_storage, mock_config):
|
||||
"""Test archiver can be initialized with various options."""
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
@ -214,7 +214,7 @@ def factory():
|
||||
class TestAudioServiceASR:
|
||||
"""Test speech-to-text (ASR) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
|
||||
"""Test successful ASR transcription in CHAT mode."""
|
||||
# Arrange
|
||||
@ -226,9 +226,7 @@ class TestAudioServiceASR:
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -242,7 +240,7 @@ class TestAudioServiceASR:
|
||||
call_args = mock_model_instance.invoke_speech2text.call_args
|
||||
assert call_args.kwargs["user"] == "user-123"
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
|
||||
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
|
||||
# Arrange
|
||||
@ -254,9 +252,7 @@ class TestAudioServiceASR:
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -351,7 +347,7 @@ class TestAudioServiceASR:
|
||||
with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
"""Test that ASR raises error when no model instance is available."""
|
||||
# Arrange
|
||||
@ -363,8 +359,7 @@ class TestAudioServiceASR:
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager to return None
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_manager.get_default_model_instance.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@ -375,7 +370,7 @@ class TestAudioServiceASR:
|
||||
class TestAudioServiceTTS:
|
||||
"""Test text-to-speech (TTS) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
|
||||
"""Test successful TTS with text input."""
|
||||
# Arrange
|
||||
@ -388,9 +383,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -412,8 +405,8 @@ class TestAudioServiceTTS:
|
||||
voice="en-US-Neural",
|
||||
)
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
|
||||
"""Test successful TTS with message ID."""
|
||||
# Arrange
|
||||
@ -437,9 +430,7 @@ class TestAudioServiceTTS:
|
||||
mock_query.first.return_value = message
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio from message"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -454,7 +445,7 @@ class TestAudioServiceTTS:
|
||||
assert result == b"audio from message"
|
||||
mock_model_instance.invoke_tts.assert_called_once()
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
|
||||
"""Test TTS uses default voice when none specified."""
|
||||
# Arrange
|
||||
@ -467,9 +458,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -486,7 +475,7 @@ class TestAudioServiceTTS:
|
||||
call_args = mock_model_instance.invoke_tts.call_args
|
||||
assert call_args.kwargs["voice"] == "default-voice"
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
|
||||
"""Test TTS gets first available voice when none is configured."""
|
||||
# Arrange
|
||||
@ -499,9 +488,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
@ -518,8 +505,8 @@ class TestAudioServiceTTS:
|
||||
call_args = mock_model_instance.invoke_tts.call_args
|
||||
assert call_args.kwargs["voice"] == "auto-voice"
|
||||
|
||||
@patch("services.audio_service.WorkflowService")
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.WorkflowService", autospec=True)
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_workflow_mode_with_draft(
|
||||
self, mock_model_manager_class, mock_workflow_service_class, factory
|
||||
):
|
||||
@ -533,14 +520,11 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock WorkflowService
|
||||
mock_workflow_service = MagicMock()
|
||||
mock_workflow_service_class.return_value = mock_workflow_service
|
||||
mock_workflow_service = mock_workflow_service_class.return_value
|
||||
mock_workflow_service.get_draft_workflow.return_value = draft_workflow
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"draft audio"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -565,7 +549,7 @@ class TestAudioServiceTTS:
|
||||
with pytest.raises(ValueError, match="Text is required"):
|
||||
AudioService.transcript_tts(app_model=app, text=None)
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None for invalid message ID format."""
|
||||
# Arrange
|
||||
@ -580,7 +564,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message doesn't exist."""
|
||||
# Arrange
|
||||
@ -601,7 +585,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message answer is empty."""
|
||||
# Arrange
|
||||
@ -627,7 +611,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS raises error when no voices are available."""
|
||||
# Arrange
|
||||
@ -640,9 +624,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = [] # No voices available
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -655,7 +637,7 @@ class TestAudioServiceTTS:
|
||||
class TestAudioServiceTTSVoices:
|
||||
"""Test TTS voice listing operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
|
||||
"""Test successful retrieval of TTS voices."""
|
||||
# Arrange
|
||||
@ -668,9 +650,7 @@ class TestAudioServiceTTSVoices:
|
||||
]
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = expected_voices
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -682,7 +662,7 @@ class TestAudioServiceTTSVoices:
|
||||
assert result == expected_voices
|
||||
mock_model_instance.get_tts_voices.assert_called_once_with(language)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS voices raises error when no model instance is available."""
|
||||
# Arrange
|
||||
@ -690,15 +670,14 @@ class TestAudioServiceTTSVoices:
|
||||
language = "en-US"
|
||||
|
||||
# Mock ModelManager to return None
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_manager.get_default_model_instance.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
|
||||
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS voices propagates exceptions from model instance."""
|
||||
# Arrange
|
||||
@ -706,9 +685,7 @@ class TestAudioServiceTTSVoices:
|
||||
language = "en-US"
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error")
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.model_name = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
|
||||
@ -1,472 +0,0 @@
|
||||
"""
|
||||
Unit tests for SegmentService.get_segments method.
|
||||
|
||||
Tests the retrieval of document segments with pagination and filtering:
|
||||
- Basic pagination (page, limit)
|
||||
- Status filtering
|
||||
- Keyword search
|
||||
- Ordering by position and id (to avoid duplicate data)
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class SegmentServiceTestDataFactory:
|
||||
"""
|
||||
Factory class for creating test data and mock objects for segment tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_segment_mock(
|
||||
segment_id: str = "segment-123",
|
||||
document_id: str = "doc-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
position: int = 1,
|
||||
content: str = "Test content",
|
||||
status: str = "completed",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock document segment.
|
||||
|
||||
Args:
|
||||
segment_id: Unique identifier for the segment
|
||||
document_id: Parent document ID
|
||||
tenant_id: Tenant ID the segment belongs to
|
||||
dataset_id: Parent dataset ID
|
||||
position: Position within the document
|
||||
content: Segment text content
|
||||
status: Indexing status
|
||||
**kwargs: Additional attributes
|
||||
|
||||
Returns:
|
||||
Mock: DocumentSegment mock object
|
||||
"""
|
||||
segment = create_autospec(DocumentSegment, instance=True)
|
||||
segment.id = segment_id
|
||||
segment.document_id = document_id
|
||||
segment.tenant_id = tenant_id
|
||||
segment.dataset_id = dataset_id
|
||||
segment.position = position
|
||||
segment.content = content
|
||||
segment.status = status
|
||||
for key, value in kwargs.items():
|
||||
setattr(segment, key, value)
|
||||
return segment
|
||||
|
||||
|
||||
class TestSegmentServiceGetSegments:
|
||||
"""
|
||||
Comprehensive unit tests for SegmentService.get_segments method.
|
||||
|
||||
Tests cover:
|
||||
- Basic pagination functionality
|
||||
- Status list filtering
|
||||
- Keyword search filtering
|
||||
- Ordering (position + id for uniqueness)
|
||||
- Empty results
|
||||
- Combined filters
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_segment_service_dependencies(self):
|
||||
"""
|
||||
Common mock setup for segment service dependencies.
|
||||
|
||||
Patches:
|
||||
- db: Database operations and pagination
|
||||
- select: SQLAlchemy query builder
|
||||
"""
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.select") as mock_select,
|
||||
):
|
||||
yield {
|
||||
"db": mock_db,
|
||||
"select": mock_select,
|
||||
}
|
||||
|
||||
def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test basic pagination functionality.
|
||||
|
||||
Verifies:
|
||||
- Query is built with document_id and tenant_id filters
|
||||
- Pagination uses correct page and limit parameters
|
||||
- Returns segments and total count
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
page = 1
|
||||
limit = 20
|
||||
|
||||
# Create mock segments
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", position=1, content="First segment"
|
||||
)
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-2", position=2, content="Second segment"
|
||||
)
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2]
|
||||
mock_paginated.total = 2
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
# Mock select builder
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
assert items[0].id == "seg-1"
|
||||
assert items[1].id == "seg-2"
|
||||
mock_segment_service_dependencies["db"].paginate.assert_called_once()
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == limit
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
assert call_kwargs["error_out"] is False
|
||||
|
||||
def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test filtering by status list.
|
||||
|
||||
Verifies:
|
||||
- Status list filter is applied to query
|
||||
- Only segments with matching status are returned
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = ["completed", "indexing"]
|
||||
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2]
|
||||
mock_paginated.total = 2
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id, tenant_id=tenant_id, status_list=status_list
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
# Verify where was called multiple times (base filters + status filter)
|
||||
assert mock_query.where.call_count >= 2
|
||||
|
||||
def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with empty status list.
|
||||
|
||||
Verifies:
|
||||
- Empty status list is handled correctly
|
||||
- No status filter is applied to avoid WHERE false condition
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = []
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id, tenant_id=tenant_id, status_list=status_list
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Should only be called once (base filters, no status filter)
|
||||
assert mock_query.where.call_count == 1
|
||||
|
||||
def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test keyword search functionality.
|
||||
|
||||
Verifies:
|
||||
- Keyword filter uses ilike for case-insensitive search
|
||||
- Search pattern includes wildcards (%keyword%)
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
keyword = "search term"
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", content="This contains search term"
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Verify where was called for base filters + keyword filter
|
||||
assert mock_query.where.call_count == 2
|
||||
|
||||
def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test ordering by position and id.
|
||||
|
||||
Verifies:
|
||||
- Results are ordered by position ASC
|
||||
- Results are secondarily ordered by id ASC to ensure uniqueness
|
||||
- This prevents duplicate data across pages when positions are not unique
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
# Create segments with same position but different ids
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", position=1, content="Content 1"
|
||||
)
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-2", position=1, content="Content 2"
|
||||
)
|
||||
segment3 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-3", position=2, content="Content 3"
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2, segment3]
|
||||
mock_paginated.total = 3
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 3
|
||||
assert total == 3
|
||||
mock_query.order_by.assert_called_once()
|
||||
|
||||
def test_get_segments_empty_results(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test when no segments match the criteria.
|
||||
|
||||
Verifies:
|
||||
- Empty list is returned for items
|
||||
- Total count is 0
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "non-existent-doc"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = []
|
||||
mock_paginated.total = 0
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert items == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with multiple filters combined.
|
||||
|
||||
Verifies:
|
||||
- All filters work together correctly
|
||||
- Status list and keyword search both applied
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = ["completed"]
|
||||
keyword = "important"
|
||||
page = 2
|
||||
limit = 10
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1",
|
||||
status="completed",
|
||||
content="This is important information",
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
status_list=status_list,
|
||||
keyword=keyword,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Verify filters: base + status + keyword
|
||||
assert mock_query.where.call_count == 3
|
||||
# Verify pagination parameters
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == limit
|
||||
|
||||
def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with None status list.
|
||||
|
||||
Verifies:
|
||||
- None status list is handled correctly
|
||||
- No status filter is applied
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
status_list=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Should only be called once (base filters only, no status filter)
|
||||
assert mock_query.where.call_count == 1
|
||||
|
||||
def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test that max_per_page is correctly set to 100.
|
||||
|
||||
Verifies:
|
||||
- max_per_page parameter is set to 100
|
||||
- This prevents excessive page sizes
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
limit = 200 # Request more than max_per_page
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = []
|
||||
mock_paginated.total = 0
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
@ -1,746 +0,0 @@
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService retrieval/list methods.
|
||||
|
||||
This test suite covers:
|
||||
- get_datasets - pagination, search, filtering, permissions
|
||||
- get_dataset - single dataset retrieval
|
||||
- get_datasets_by_ids - bulk retrieval
|
||||
- get_process_rules - dataset processing rules
|
||||
- get_dataset_queries - dataset query history
|
||||
- get_related_apps - apps using the dataset
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
Dataset,
|
||||
DatasetPermission,
|
||||
DatasetPermissionEnum,
|
||||
DatasetProcessRule,
|
||||
DatasetQuery,
|
||||
)
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
||||
|
||||
class DatasetRetrievalTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset retrieval tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
name: str = "Test Dataset",
|
||||
tenant_id: str = "tenant-123",
|
||||
created_by: str = "user-123",
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.name = name
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.created_by = created_by
|
||||
dataset.permission = permission
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(
|
||||
account_id: str = "account-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock account."""
|
||||
account = create_autospec(Account, instance=True)
|
||||
account.id = account_id
|
||||
account.current_tenant_id = tenant_id
|
||||
account.current_role = role
|
||||
for key, value in kwargs.items():
|
||||
setattr(account, key, value)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
account_id: str = "account-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset permission."""
|
||||
permission = Mock(spec=DatasetPermission)
|
||||
permission.dataset_id = dataset_id
|
||||
permission.account_id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(permission, key, value)
|
||||
return permission
|
||||
|
||||
@staticmethod
|
||||
def create_process_rule_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
mode: str = "automatic",
|
||||
rules: dict | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset process rule."""
|
||||
process_rule = Mock(spec=DatasetProcessRule)
|
||||
process_rule.dataset_id = dataset_id
|
||||
process_rule.mode = mode
|
||||
process_rule.rules_dict = rules or {}
|
||||
for key, value in kwargs.items():
|
||||
setattr(process_rule, key, value)
|
||||
return process_rule
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_query_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
query_id: str = "query-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset query."""
|
||||
dataset_query = Mock(spec=DatasetQuery)
|
||||
dataset_query.id = query_id
|
||||
dataset_query.dataset_id = dataset_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset_query, key, value)
|
||||
return dataset_query
|
||||
|
||||
@staticmethod
|
||||
def create_app_dataset_join_mock(
|
||||
app_id: str = "app-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock app-dataset join."""
|
||||
join = Mock(spec=AppDatasetJoin)
|
||||
join.app_id = app_id
|
||||
join.dataset_id = dataset_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(join, key, value)
|
||||
return join
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasets:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.get_datasets method.
|
||||
|
||||
This test suite covers:
|
||||
- Pagination
|
||||
- Search functionality
|
||||
- Tag filtering
|
||||
- Permission-based filtering (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
|
||||
- Role-based filtering (OWNER, DATASET_OPERATOR, NORMAL)
|
||||
- include_all flag
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_datasets tests."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.db.paginate") as mock_paginate,
|
||||
patch("services.dataset_service.TagService") as mock_tag_service,
|
||||
):
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"paginate": mock_paginate,
|
||||
"tag_service": mock_tag_service,
|
||||
}
|
||||
|
||||
# ==================== Basic Retrieval Tests ====================
|
||||
|
||||
def test_get_datasets_basic_pagination(self, mock_dependencies):
|
||||
"""Test basic pagination without user or filters."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id=f"dataset-{i}", name=f"Dataset {i}", tenant_id=tenant_id
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
mock_paginate_result.total = 5
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 5
|
||||
assert total == 5
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_with_search(self, mock_dependencies):
|
||||
"""Test get_datasets with search keyword."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
search = "test"
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1", name="Test Dataset", tenant_id=tenant_id
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, search=search)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_with_tag_filtering(self, mock_dependencies):
|
||||
"""Test get_datasets with tag_ids filtering."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
tag_ids = ["tag-1", "tag-2"]
|
||||
|
||||
# Mock tag service
|
||||
target_ids = ["dataset-1", "dataset-2"]
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.return_value = target_ids
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
for dataset_id in target_ids
|
||||
]
|
||||
mock_paginate_result.total = 2
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 2
|
||||
assert total == 2
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_called_once_with(
|
||||
"knowledge", tenant_id, tag_ids
|
||||
)
|
||||
|
||||
def test_get_datasets_with_empty_tag_ids(self, mock_dependencies):
|
||||
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
tag_ids = []
|
||||
|
||||
# Mock pagination result - when tag_ids is empty, tag filtering is skipped
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
|
||||
|
||||
# Assert
|
||||
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
# Tag service should not be called when tag_ids is empty
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_not_called()
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
# ==================== Permission-Based Filtering Tests ====================
|
||||
|
||||
def test_get_datasets_without_user_shows_only_all_team(self, mock_dependencies):
|
||||
"""Test that without user, only ALL_TEAM datasets are shown."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, user=None)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_owner_with_include_all(self, mock_dependencies):
|
||||
"""Test that OWNER with include_all=True sees all datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id="owner-123", tenant_id=tenant_id, role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (empty - owner doesn't need explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, tenant_id=tenant_id, user=user, include_all=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
|
||||
def test_get_datasets_normal_user_only_me_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees ONLY_ME datasets they created."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "user-123"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (no explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
created_by=user_id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_all_team_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees ALL_TEAM datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id="user-123", tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (no explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_partial_team_with_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "user-123"
|
||||
dataset_id = "dataset-1"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - user has permission
|
||||
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset_id, account_id=user_id
|
||||
)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = [permission]
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_with_permissions(self, mock_dependencies):
|
||||
"""Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "operator-123"
|
||||
dataset_id = "dataset-1"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - operator has permission
|
||||
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset_id, account_id=user_id
|
||||
)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = [permission]
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_without_permissions(self, mock_dependencies):
|
||||
"""Test that DATASET_OPERATOR without permissions returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "operator-123"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - no permissions
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestDatasetServiceGetDataset:
|
||||
"""Comprehensive unit tests for DatasetService.get_dataset method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_dataset tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_dataset_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of a single dataset."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = dataset
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == dataset_id
|
||||
mock_query.filter_by.assert_called_once_with(id=dataset_id)
|
||||
|
||||
def test_get_dataset_not_found(self, mock_dependencies):
|
||||
"""Test retrieval when dataset doesn't exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasetsByIds:
|
||||
"""Comprehensive unit tests for DatasetService.get_datasets_by_ids method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_datasets_by_ids tests."""
|
||||
with patch("services.dataset_service.db.paginate") as mock_paginate:
|
||||
yield {"paginate": mock_paginate}
|
||||
|
||||
def test_get_datasets_by_ids_success(self, mock_dependencies):
|
||||
"""Test successful bulk retrieval of datasets by IDs."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
mock_paginate_result.total = len(dataset_ids)
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
assert all(dataset.id in dataset_ids for dataset in datasets)
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_by_ids_empty_list(self, mock_dependencies):
|
||||
"""Test get_datasets_by_ids with empty list returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_ids = []
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
mock_dependencies["paginate"].assert_not_called()
|
||||
|
||||
def test_get_datasets_by_ids_none_list(self, mock_dependencies):
|
||||
"""Test get_datasets_by_ids with None returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(None, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
mock_dependencies["paginate"].assert_not_called()
|
||||
|
||||
|
||||
class TestDatasetServiceGetProcessRules:
|
||||
"""Comprehensive unit tests for DatasetService.get_process_rules method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_process_rules tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_process_rules_with_existing_rule(self, mock_dependencies):
|
||||
"""Test retrieval of process rules when rule exists."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
rules_data = {
|
||||
"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
|
||||
"segmentation": {"delimiter": "\n", "max_tokens": 500},
|
||||
}
|
||||
process_rule = DatasetRetrievalTestDataFactory.create_process_rule_mock(
|
||||
dataset_id=dataset_id, mode="custom", rules=rules_data
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = process_rule
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_process_rules(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result["mode"] == "custom"
|
||||
assert result["rules"] == rules_data
|
||||
|
||||
def test_get_process_rules_without_existing_rule(self, mock_dependencies):
|
||||
"""Test retrieval of process rules when no rule exists (returns defaults)."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = None
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_process_rules(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result["mode"] == DocumentService.DEFAULT_RULES["mode"]
|
||||
assert "rules" in result
|
||||
assert result["rules"] == DocumentService.DEFAULT_RULES["rules"]
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasetQueries:
|
||||
"""Comprehensive unit tests for DatasetService.get_dataset_queries method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_dataset_queries tests."""
|
||||
with patch("services.dataset_service.db.paginate") as mock_paginate:
|
||||
yield {"paginate": mock_paginate}
|
||||
|
||||
def test_get_dataset_queries_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of dataset queries."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_query_mock(dataset_id=dataset_id, query_id=f"query-{i}")
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
|
||||
|
||||
# Assert
|
||||
assert len(queries) == 3
|
||||
assert total == 3
|
||||
assert all(query.dataset_id == dataset_id for query in queries)
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_dataset_queries_empty_result(self, mock_dependencies):
|
||||
"""Test retrieval when no queries exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result (empty)
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = []
|
||||
mock_paginate_result.total = 0
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
|
||||
|
||||
# Assert
|
||||
assert queries == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestDatasetServiceGetRelatedApps:
|
||||
"""Comprehensive unit tests for DatasetService.get_related_apps method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_related_apps tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_related_apps_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of related apps."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock app-dataset joins
|
||||
app_joins = [
|
||||
DatasetRetrievalTestDataFactory.create_app_dataset_join_mock(app_id=f"app-{i}", dataset_id=dataset_id)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.all.return_value = app_joins
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(join.dataset_id == dataset_id for join in result)
|
||||
mock_query.where.assert_called_once()
|
||||
mock_query.where.return_value.order_by.assert_called_once()
|
||||
|
||||
def test_get_related_apps_empty_result(self, mock_dependencies):
|
||||
"""Test retrieval when no related apps exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning empty list
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
@ -1,661 +0,0 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetUpdateTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset update tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
provider: str = "vendor",
|
||||
name: str = "old_name",
|
||||
description: str = "old_description",
|
||||
indexing_technique: str = "high_quality",
|
||||
retrieval_model: str = "old_model",
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
collection_binding_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.provider = provider
|
||||
dataset.name = name
|
||||
dataset.description = description
|
||||
dataset.indexing_technique = indexing_technique
|
||||
dataset.retrieval_model = retrieval_model
|
||||
dataset.embedding_model_provider = embedding_model_provider
|
||||
dataset.embedding_model = embedding_model
|
||||
dataset.collection_binding_id = collection_binding_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(user_id: str = "user-789") -> Mock:
|
||||
"""Create a mock user."""
|
||||
user = Mock()
|
||||
user.id = user_id
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_external_binding_mock(
|
||||
external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id"
|
||||
) -> Mock:
|
||||
"""Create a mock external knowledge binding."""
|
||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
||||
binding.external_knowledge_id = external_knowledge_id
|
||||
binding.external_knowledge_api_id = external_knowledge_api_id
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@staticmethod
|
||||
def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
|
||||
"""Create a mock collection binding."""
|
||||
binding = Mock()
|
||||
binding.id = binding_id
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
|
||||
"""Create a mock current user."""
|
||||
current_user = create_autospec(Account, instance=True)
|
||||
current_user.current_tenant_id = tenant_id
|
||||
return current_user
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.update_dataset method.
|
||||
|
||||
This test suite covers all supported scenarios including:
|
||||
- External dataset updates
|
||||
- Internal dataset updates with different indexing techniques
|
||||
- Embedding model updates
|
||||
- Permission checks
|
||||
- Error conditions and edge cases
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
|
||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
|
||||
yield {
|
||||
"get_dataset": mock_get_dataset,
|
||||
"check_permission": mock_check_perm,
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
"has_dataset_same_name": has_dataset_same_name,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_provider_dependencies(self):
|
||||
"""Mock setup for external provider tests."""
|
||||
with patch("services.dataset_service.Session") as mock_session:
|
||||
from extensions.ext_database import db
|
||||
|
||||
with patch.object(db.__class__, "engine", new_callable=Mock):
|
||||
session_mock = Mock()
|
||||
mock_session.return_value.__enter__.return_value = session_mock
|
||||
yield session_mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_internal_provider_dependencies(self):
|
||||
"""Mock setup for internal provider tests."""
|
||||
with (
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch(
|
||||
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
|
||||
) as mock_get_binding,
|
||||
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
|
||||
patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task,
|
||||
patch(
|
||||
"services.dataset_service.current_user", create_autospec(Account, instance=True)
|
||||
) as mock_current_user,
|
||||
):
|
||||
mock_current_user.current_tenant_id = "tenant-123"
|
||||
yield {
|
||||
"model_manager": mock_model_manager,
|
||||
"get_binding": mock_get_binding,
|
||||
"task": mock_task,
|
||||
"regenerate_task": mock_regenerate_task,
|
||||
"current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]):
|
||||
"""Helper method to verify database update calls."""
|
||||
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_updates)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]):
|
||||
"""Helper method to verify external dataset updates."""
|
||||
assert mock_dataset.name == update_data.get("name", mock_dataset.name)
|
||||
assert mock_dataset.description == update_data.get("description", mock_dataset.description)
|
||||
assert mock_dataset.retrieval_model == update_data.get("external_retrieval_model", mock_dataset.retrieval_model)
|
||||
|
||||
if "external_knowledge_id" in update_data:
|
||||
assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"]
|
||||
if "external_knowledge_api_id" in update_data:
|
||||
assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
|
||||
|
||||
# ==================== External Dataset Tests ====================
|
||||
|
||||
def test_update_external_dataset_success(
|
||||
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
|
||||
):
|
||||
"""Test successful update of external dataset."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="external", name="old_name", description="old_description", retrieval_model="old_model"
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
binding = DatasetUpdateTestDataFactory.create_external_binding_mock()
|
||||
|
||||
# Mock external knowledge binding query
|
||||
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = binding
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"external_retrieval_model": "new_model",
|
||||
"permission": "only_me",
|
||||
"external_knowledge_id": "new_knowledge_id",
|
||||
"external_knowledge_api_id": "new_api_id",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify dataset and binding updates
|
||||
self._assert_external_dataset_update(dataset, binding, update_data)
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add.assert_any_call(dataset)
|
||||
mock_db.add.assert_any_call(binding)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge id is missing."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge id is required" in str(context.value)
|
||||
|
||||
def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge api id is missing."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge api id is required" in str(context.value)
|
||||
|
||||
def test_update_external_dataset_binding_not_found_error(
|
||||
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
|
||||
):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock external knowledge binding query returning None
|
||||
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"external_knowledge_id": "knowledge_id",
|
||||
"external_knowledge_api_id": "api_id",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge binding not found" in str(context.value)
|
||||
|
||||
# ==================== Internal Dataset Basic Tests ====================
|
||||
|
||||
def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
|
||||
"""Test successful update of internal dataset with basic fields."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify permission check was called
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify database update was called with correct filtered data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies):
|
||||
"""Test that None values are filtered out except for description field."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": None, # Should be included
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": None, # Should be filtered out
|
||||
"embedding_model": None, # Should be filtered out
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with filtered data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"description": None, # Description should be included even if None
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
actual_call_args = mock_dataset_service_dependencies[
|
||||
"db_session"
|
||||
].query.return_value.filter_by.return_value.update.call_args[0][0]
|
||||
# Remove timestamp for comparison as it's dynamic
|
||||
del actual_call_args["updated_at"]
|
||||
del expected_filtered_data["updated_at"]
|
||||
|
||||
assert actual_call_args == expected_filtered_data
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Indexing Technique Switch Tests ====================
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_economy(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset indexing technique to economy."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with embedding model fields cleared
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "economy",
|
||||
"embedding_model": None,
|
||||
"embedding_model_provider": None,
|
||||
"collection_binding_id": None,
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_high_quality(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset indexing technique to high_quality."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock embedding model
|
||||
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock()
|
||||
mock_internal_provider_dependencies[
|
||||
"model_manager"
|
||||
].return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
# Mock collection binding
|
||||
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock()
|
||||
mock_internal_provider_dependencies["get_binding"].return_value = binding
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify embedding model was validated
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
# Verify collection binding was retrieved
|
||||
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002")
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"embedding_model_provider": "openai",
|
||||
"collection_binding_id": "binding-456",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify vector index task was triggered
|
||||
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add")
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Embedding Model Update Tests ====================
|
||||
|
||||
def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies):
|
||||
"""Test updating internal dataset without changing embedding model."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with existing embedding model preserved
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"collection_binding_id": "binding-123",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_embedding_model_update(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset with new embedding model."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock embedding model
|
||||
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small")
|
||||
mock_internal_provider_dependencies[
|
||||
"model_manager"
|
||||
].return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
# Mock collection binding
|
||||
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789")
|
||||
mock_internal_provider_dependencies["get_binding"].return_value = binding
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify embedding model was validated
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Verify collection binding was retrieved
|
||||
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small")
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"embedding_model_provider": "openai",
|
||||
"collection_binding_id": "binding-789",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify vector index task was triggered
|
||||
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update")
|
||||
|
||||
# Verify regenerate summary index task was triggered (when embedding_model changes)
|
||||
mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with(
|
||||
"dataset-123",
|
||||
regenerate_reason="embedding_model_changed",
|
||||
regenerate_vectors_only=True,
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies):
|
||||
"""Test updating internal dataset without changing indexing technique."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality", # Same as current
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"collection_binding_id": "binding-123",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when dataset is not found."""
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = None
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "Dataset not found" in str(context.value)
|
||||
|
||||
def test_update_dataset_permission_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when user doesn't have permission."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock()
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
|
||||
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
def test_update_internal_dataset_embedding_model_error(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test error when embedding model is not available."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock model manager to raise error
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception(
|
||||
"No Embedding Model available"
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "invalid_provider",
|
||||
"embedding_model": "invalid_model",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(Exception) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "No Embedding Model available".lower() in str(context.value).lower()
|
||||
@ -6,66 +6,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestArchivedWorkflowRunDeletion:
|
||||
def test_delete_by_run_id_returns_error_when_run_missing(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
repo = MagicMock()
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
mock_db = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
|
||||
patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
|
||||
),
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
|
||||
):
|
||||
result = deleter.delete_by_run_id("run-1")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "Workflow run run-1 not found"
|
||||
repo.get_archived_run_ids.assert_not_called()
|
||||
|
||||
def test_delete_by_run_id_returns_error_when_not_archived(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
repo = MagicMock()
|
||||
repo.get_archived_run_ids.return_value = set()
|
||||
run = MagicMock()
|
||||
run.id = "run-1"
|
||||
run.tenant_id = "tenant-1"
|
||||
|
||||
session = MagicMock()
|
||||
session.get.return_value = run
|
||||
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
mock_db = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
|
||||
patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
|
||||
),
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
|
||||
patch.object(deleter, "_delete_run") as mock_delete_run,
|
||||
):
|
||||
result = deleter.delete_by_run_id("run-1")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "Workflow run run-1 is not archived"
|
||||
mock_delete_run.assert_not_called()
|
||||
|
||||
def test_delete_by_run_id_calls_delete_run(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
@ -88,65 +28,20 @@ class TestArchivedWorkflowRunDeletion:
|
||||
with (
|
||||
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
|
||||
patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker",
|
||||
return_value=session_maker,
|
||||
autospec=True,
|
||||
),
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
|
||||
patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run,
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True),
|
||||
patch.object(
|
||||
deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True
|
||||
) as mock_delete_run,
|
||||
):
|
||||
result = deleter.delete_by_run_id("run-1")
|
||||
|
||||
assert result.success is True
|
||||
mock_delete_run.assert_called_once_with(run)
|
||||
|
||||
def test_delete_batch_uses_repo(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
repo = MagicMock()
|
||||
run1 = MagicMock()
|
||||
run1.id = "run-1"
|
||||
run1.tenant_id = "tenant-1"
|
||||
run2 = MagicMock()
|
||||
run2.id = "run-2"
|
||||
run2.tenant_id = "tenant-1"
|
||||
repo.get_archived_runs_by_time_range.return_value = [run1, run2]
|
||||
|
||||
session = MagicMock()
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
start_date = MagicMock()
|
||||
end_date = MagicMock()
|
||||
mock_db = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
|
||||
patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
|
||||
),
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
|
||||
patch.object(
|
||||
deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)]
|
||||
) as mock_delete_run,
|
||||
):
|
||||
results = deleter.delete_batch(
|
||||
tenant_ids=["tenant-1"],
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
repo.get_archived_runs_by_time_range.assert_called_once_with(
|
||||
session=session,
|
||||
tenant_ids=["tenant-1"],
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=2,
|
||||
)
|
||||
assert mock_delete_run.call_count == 2
|
||||
|
||||
def test_delete_run_dry_run(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
@ -155,26 +50,8 @@ class TestArchivedWorkflowRunDeletion:
|
||||
run.id = "run-1"
|
||||
run.tenant_id = "tenant-1"
|
||||
|
||||
with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
|
||||
with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo:
|
||||
result = deleter._delete_run(run)
|
||||
|
||||
assert result.success is True
|
||||
mock_get_repo.assert_not_called()
|
||||
|
||||
def test_delete_run_calls_repo(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
run = MagicMock()
|
||||
run.id = "run-1"
|
||||
run.tenant_id = "tenant-1"
|
||||
|
||||
repo = MagicMock()
|
||||
repo.delete_runs_with_related.return_value = {"runs": 1}
|
||||
|
||||
with patch.object(deleter, "_get_workflow_run_repo", return_value=repo):
|
||||
result = deleter._delete_run(run)
|
||||
|
||||
assert result.success is True
|
||||
assert result.deleted_counts == {"runs": 1}
|
||||
repo.delete_runs_with_related.assert_called_once()
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from models.dataset import Document
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
@ -9,25 +6,3 @@ def test_normalize_display_status_alias_mapping():
|
||||
assert DocumentService.normalize_display_status("enabled") == "available"
|
||||
assert DocumentService.normalize_display_status("archived") == "archived"
|
||||
assert DocumentService.normalize_display_status("unknown") is None
|
||||
|
||||
|
||||
def test_build_display_status_filters_available():
|
||||
filters = DocumentService.build_display_status_filters("available")
|
||||
assert len(filters) == 3
|
||||
for condition in filters:
|
||||
assert condition is not None
|
||||
|
||||
|
||||
def test_apply_display_status_filter_applies_when_status_present():
|
||||
query = sa.select(Document)
|
||||
filtered = DocumentService.apply_display_status_filter(query, "queuing")
|
||||
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "WHERE" in compiled
|
||||
assert "documents.indexing_status = 'waiting'" in compiled
|
||||
|
||||
|
||||
def test_apply_display_status_filter_returns_same_when_invalid():
|
||||
query = sa.select(Document)
|
||||
filtered = DocumentService.apply_display_status_filter(query, "invalid")
|
||||
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "WHERE" not in compiled
|
||||
|
||||
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import App, DefaultEndUserSessionID, EndUser
|
||||
from models.model import App, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
@ -44,113 +44,6 @@ class TestEndUserServiceFactory:
|
||||
return end_user
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUser:
|
||||
"""
|
||||
Unit tests for EndUserService.get_or_create_end_user method.
|
||||
|
||||
This test suite covers:
|
||||
- Creating new end users
|
||||
- Retrieving existing end users
|
||||
- Default session ID handling
|
||||
- Anonymous user creation
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
# Test 01: Get or create with custom user_id
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_or_create_end_user_with_custom_user_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test getting or creating end user with custom user_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user_id = "custom-user-123"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
# Verify the created user has correct attributes
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.tenant_id == app.tenant_id
|
||||
assert added_user.app_id == app.id
|
||||
assert added_user.session_id == user_id
|
||||
assert added_user.type == InvokeFrom.SERVICE_API
|
||||
assert added_user.is_anonymous is False
|
||||
|
||||
# Test 02: Get or create without user_id (default session)
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_or_create_end_user_without_user_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test getting or creating end user without user_id uses default session."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_model=app, user_id=None)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
# Verify _is_anonymous is set correctly (property always returns False)
|
||||
assert added_user._is_anonymous is True
|
||||
|
||||
# Test 03: Get existing end user
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_existing_end_user(self, mock_db, mock_session_class, factory):
|
||||
"""Test retrieving an existing end user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user_id = "existing-user-123"
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
session_id=user_id,
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
mock_session.add.assert_not_called() # Should not create new user
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
"""
|
||||
Unit tests for EndUserService.get_or_create_end_user_by_type method.
|
||||
@ -167,226 +60,6 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
# Test 04: Create new end user with SERVICE_API type
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_end_user_service_api_type(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating new end user with SERVICE_API type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == InvokeFrom.SERVICE_API
|
||||
assert added_user.tenant_id == tenant_id
|
||||
assert added_user.app_id == app_id
|
||||
assert added_user.session_id == user_id
|
||||
|
||||
# Test 05: Create new end user with WEB_APP type
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_end_user_web_app_type(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating new end user with WEB_APP type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.WEB_APP,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == InvokeFrom.WEB_APP
|
||||
|
||||
# Test 06: Upgrade legacy end user type
|
||||
@patch("services.end_user_service.logger")
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_upgrade_legacy_end_user_type(self, mock_db, mock_session_class, mock_logger, factory):
|
||||
"""Test upgrading legacy end user with different type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
# Existing user with old type
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
session_id=user_id,
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act - Request with different type
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.WEB_APP,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
assert existing_user.type == InvokeFrom.WEB_APP # Type should be updated
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_logger.info.assert_called_once()
|
||||
# Verify log message contains upgrade info
|
||||
log_call = mock_logger.info.call_args[0][0]
|
||||
assert "Upgrading legacy EndUser" in log_call
|
||||
|
||||
# Test 07: Get existing end user with matching type (no upgrade needed)
|
||||
@patch("services.end_user_service.logger")
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_existing_end_user_matching_type(self, mock_db, mock_session_class, mock_logger, factory):
|
||||
"""Test retrieving existing end user with matching type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
session_id=user_id,
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act - Request with same type
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
assert existing_user.type == InvokeFrom.SERVICE_API
|
||||
# No commit should be called (no type update needed)
|
||||
mock_session.commit.assert_not_called()
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
# Test 08: Create anonymous user with default session ID
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_anonymous_user_with_default_session(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating anonymous user when user_id is None."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
# Verify _is_anonymous is set correctly (property always returns False)
|
||||
assert added_user._is_anonymous is True
|
||||
assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
# Test 09: Query ordering prioritizes matching type
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_query_ordering_prioritizes_matching_type(self, mock_db, mock_session_class, factory):
|
||||
"""Test that query ordering prioritizes records with matching type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify order_by was called (for type prioritization)
|
||||
mock_query.order_by.assert_called_once()
|
||||
|
||||
# Test 10: Session context manager properly closes
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
@ -420,117 +93,3 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
# Verify context manager was entered and exited
|
||||
mock_context.__enter__.assert_called_once()
|
||||
mock_context.__exit__.assert_called_once()
|
||||
|
||||
# Test 11: External user ID matches session ID
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_external_user_id_matches_session_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test that external_user_id is set to match session_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "custom-external-id"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.external_user_id == user_id
|
||||
assert added_user.session_id == user_id
|
||||
|
||||
# Test 12: Different InvokeFrom types
|
||||
@pytest.mark.parametrize(
|
||||
"invoke_type",
|
||||
[
|
||||
InvokeFrom.SERVICE_API,
|
||||
InvokeFrom.WEB_APP,
|
||||
InvokeFrom.EXPLORE,
|
||||
InvokeFrom.DEBUGGER,
|
||||
],
|
||||
)
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_end_user_with_different_invoke_types(self, mock_db, mock_session_class, invoke_type, factory):
|
||||
"""Test creating end users with different InvokeFrom types."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=invoke_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == invoke_type
|
||||
|
||||
|
||||
class TestEndUserServiceGetEndUserById:
|
||||
"""Unit tests for EndUserService.get_end_user_by_id."""
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_end_user_by_id_returns_end_user(self, mock_db, mock_session_class):
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
end_user_id = "end-user-789"
|
||||
existing_user = MagicMock(spec=EndUser)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)
|
||||
|
||||
assert result == existing_user
|
||||
mock_session.query.assert_called_once_with(EndUser)
|
||||
mock_query.where.assert_called_once()
|
||||
assert len(mock_query.where.call_args[0]) == 3
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_end_user_by_id_returns_none(self, mock_db, mock_session_class):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
result = EndUserService.get_end_user_by_id(tenant_id="tenant", app_id="app", end_user_id="end-user")
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
|
||||
from services import message_service
|
||||
|
||||
|
||||
class _FakeMessage:
|
||||
def __init__(self, message_id: str):
|
||||
self.id = message_id
|
||||
self.extra_contents = None
|
||||
|
||||
def set_extra_contents(self, contents):
|
||||
self.extra_contents = contents
|
||||
|
||||
|
||||
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_by_message_ids": lambda _self, message_ids: [
|
||||
[
|
||||
HumanInputContent(
|
||||
workflow_run_id="workflow-run-1",
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
rendered_content="Rendered",
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
),
|
||||
)
|
||||
],
|
||||
[],
|
||||
]
|
||||
},
|
||||
)()
|
||||
|
||||
monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
|
||||
|
||||
message_service.attach_message_extra_contents(messages)
|
||||
|
||||
assert messages[0].extra_contents == [
|
||||
{
|
||||
"type": "human_input",
|
||||
"workflow_run_id": "workflow-run-1",
|
||||
"submitted": True,
|
||||
"form_submission_data": {
|
||||
"node_id": "node-1",
|
||||
"node_title": "Approval",
|
||||
"rendered_content": "Rendered",
|
||||
"action_id": "approve",
|
||||
"action_text": "Approve",
|
||||
},
|
||||
}
|
||||
]
|
||||
assert messages[1].extra_contents == []
|
||||
@ -402,7 +402,7 @@ class TestBillingDisabledPolicyFilterMessageIds:
|
||||
class TestCreateMessageCleanPolicy:
|
||||
"""Unit tests for create_message_clean_policy factory function."""
|
||||
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config")
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
|
||||
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
|
||||
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
|
||||
# Arrange
|
||||
@ -414,8 +414,8 @@ class TestCreateMessageCleanPolicy:
|
||||
# Assert
|
||||
assert isinstance(policy, BillingDisabledPolicy)
|
||||
|
||||
@patch("services.retention.conversation.messages_clean_policy.BillingService")
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config")
|
||||
@patch("services.retention.conversation.messages_clean_policy.BillingService", autospec=True)
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
|
||||
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
|
||||
"""Test that BillingSandboxPolicy is created with correct internal values."""
|
||||
# Arrange
|
||||
@ -554,7 +554,7 @@ class TestMessagesCleanServiceFromDays:
|
||||
MessagesCleanService.from_days(policy=policy, days=-1)
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
@ -586,7 +586,7 @@ class TestMessagesCleanServiceFromDays:
|
||||
dry_run = True
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
@ -613,7 +613,7 @@ class TestMessagesCleanServiceFromDays:
|
||||
policy = BillingDisabledPolicy()
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
|
||||
@ -134,8 +134,8 @@ def factory():
|
||||
class TestRecommendedAppServiceGetApps:
|
||||
"""Test get_recommended_apps_and_categories operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of recommended apps when apps are returned."""
|
||||
# Arrange
|
||||
@ -161,8 +161,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
|
||||
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback to builtin when no recommended apps are returned."""
|
||||
# Arrange
|
||||
@ -199,8 +199,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
# Verify fallback was called with en-US (hardcoded)
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback when recommended_apps key is None."""
|
||||
# Arrange
|
||||
@ -232,8 +232,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
assert result == builtin_response
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
|
||||
"""Test retrieval with different language codes."""
|
||||
# Arrange
|
||||
@ -262,8 +262,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
assert result["recommended_apps"][0]["id"] == f"app-{language}"
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that correct factory is selected based on mode."""
|
||||
# Arrange
|
||||
@ -292,8 +292,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
class TestRecommendedAppServiceGetDetail:
|
||||
"""Test get_recommend_app_detail operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of app detail."""
|
||||
# Arrange
|
||||
@ -324,8 +324,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
assert result["name"] == "Productivity App"
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail retrieval with different factory modes."""
|
||||
# Arrange
|
||||
@ -352,8 +352,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
assert result["name"] == f"App from {mode}"
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that None is returned when app is not found."""
|
||||
# Arrange
|
||||
@ -375,8 +375,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
assert result is None
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
|
||||
"""Test handling of empty dict response."""
|
||||
# Arrange
|
||||
@ -397,8 +397,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail with complex model configuration."""
|
||||
# Arrange
|
||||
|
||||
@ -3,7 +3,6 @@ Unit tests for workflow run restore functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class TestWorkflowRunRestore:
|
||||
@ -36,30 +35,3 @@ class TestWorkflowRunRestore:
|
||||
assert result["created_at"].year == 2024
|
||||
assert result["created_at"].month == 1
|
||||
assert result["name"] == "test"
|
||||
|
||||
def test_restore_table_records_returns_rowcount(self):
|
||||
"""Restore should return inserted rowcount."""
|
||||
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
|
||||
|
||||
session = MagicMock()
|
||||
session.execute.return_value = MagicMock(rowcount=2)
|
||||
|
||||
restore = WorkflowRunRestore()
|
||||
records = [{"id": "p1", "workflow_run_id": "r1", "created_at": "2024-01-01T00:00:00"}]
|
||||
|
||||
restored = restore._restore_table_records(session, "workflow_pauses", records, schema_version="1.0")
|
||||
|
||||
assert restored == 2
|
||||
session.execute.assert_called_once()
|
||||
|
||||
def test_restore_table_records_unknown_table(self):
|
||||
"""Unknown table names should be ignored gracefully."""
|
||||
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
|
||||
|
||||
session = MagicMock()
|
||||
|
||||
restore = WorkflowRunRestore()
|
||||
restored = restore._restore_table_records(session, "unknown_table", [{"id": "x1"}], schema_version="1.0")
|
||||
|
||||
assert restored == 0
|
||||
session.execute.assert_not_called()
|
||||
|
||||
@ -201,8 +201,8 @@ def factory():
|
||||
class TestSavedMessageServicePagination:
|
||||
"""Test saved message pagination operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an Account user."""
|
||||
# Arrange
|
||||
@ -247,8 +247,8 @@ class TestSavedMessageServicePagination:
|
||||
include_ids=["msg-0", "msg-1", "msg-2"],
|
||||
)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an EndUser."""
|
||||
# Arrange
|
||||
@ -301,8 +301,8 @@ class TestSavedMessageServicePagination:
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with last_id parameter."""
|
||||
# Arrange
|
||||
@ -340,8 +340,8 @@ class TestSavedMessageServicePagination:
|
||||
call_args = mock_message_pagination.call_args
|
||||
assert call_args.kwargs["last_id"] == last_id
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination when user has no saved messages."""
|
||||
# Arrange
|
||||
@ -377,8 +377,8 @@ class TestSavedMessageServicePagination:
|
||||
class TestSavedMessageServiceSave:
|
||||
"""Test save message operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an Account user."""
|
||||
# Arrange
|
||||
@ -407,8 +407,8 @@ class TestSavedMessageServiceSave:
|
||||
assert saved_message.created_by_role == "account"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an EndUser."""
|
||||
# Arrange
|
||||
@ -437,7 +437,7 @@ class TestSavedMessageServiceSave:
|
||||
assert saved_message.created_by_role == "end_user"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that saving without user is a no-op."""
|
||||
# Arrange
|
||||
@ -451,8 +451,8 @@ class TestSavedMessageServiceSave:
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that saving an already saved message is idempotent."""
|
||||
# Arrange
|
||||
@ -480,8 +480,8 @@ class TestSavedMessageServiceSave:
|
||||
mock_db_session.commit.assert_not_called()
|
||||
mock_get_message.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that save validates message exists through MessageService."""
|
||||
# Arrange
|
||||
@ -508,7 +508,7 @@ class TestSavedMessageServiceSave:
|
||||
class TestSavedMessageServiceDelete:
|
||||
"""Test delete saved message operations."""
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_saved_message_for_account(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an Account user."""
|
||||
# Arrange
|
||||
@ -535,7 +535,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an EndUser."""
|
||||
# Arrange
|
||||
@ -562,7 +562,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting without user is a no-op."""
|
||||
# Arrange
|
||||
@ -576,7 +576,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting a non-existent saved message is a no-op."""
|
||||
# Arrange
|
||||
@ -597,7 +597,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
|
||||
"""Test that delete only removes the user's own saved message."""
|
||||
# Arrange
|
||||
|
||||
@ -315,7 +315,7 @@ class TestTagServiceRetrieval:
|
||||
- get_tags_by_target_id: Get all tags bound to a specific target
|
||||
"""
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tags_with_binding_counts(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags with their binding counts.
|
||||
@ -372,7 +372,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify database query was called
|
||||
mock_db_session.query.assert_called_once()
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tags_with_keyword_filter(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags filtered by keyword (case-insensitive).
|
||||
@ -426,7 +426,7 @@ class TestTagServiceRetrieval:
|
||||
# 2. Additional WHERE clause for keyword filtering
|
||||
assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_target_ids_by_tag_ids(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving target IDs by tag IDs.
|
||||
@ -482,7 +482,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify both queries were executed
|
||||
assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that empty tag_ids returns empty list.
|
||||
@ -510,7 +510,7 @@ class TestTagServiceRetrieval:
|
||||
assert results == [], "Should return empty list for empty input"
|
||||
mock_db_session.scalars.assert_not_called(), "Should not query database for empty input"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tag_by_tag_name(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags by name.
|
||||
@ -552,7 +552,7 @@ class TestTagServiceRetrieval:
|
||||
assert len(results) == 1, "Should find exactly one tag"
|
||||
assert results[0].name == tag_name, "Tag name should match"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that missing tag_type or tag_name returns empty list.
|
||||
@ -580,7 +580,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify no database queries were executed
|
||||
mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tags_by_target_id(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags associated with a specific target.
|
||||
@ -651,10 +651,10 @@ class TestTagServiceCRUD:
|
||||
- get_tag_binding_count: Get count of bindings for a tag
|
||||
"""
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.uuid.uuid4")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.uuid.uuid4", autospec=True)
|
||||
def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
Test creating a new tag.
|
||||
@ -709,8 +709,8 @@ class TestTagServiceCRUD:
|
||||
assert added_tag.created_by == "user-123", "Created by should match current user"
|
||||
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
Test that creating a tag with duplicate name raises ValueError.
|
||||
@ -740,9 +740,9 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(ValueError, match="Tag name already exists"):
|
||||
TagService.save_tags(args)
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
Test updating a tag name.
|
||||
@ -792,9 +792,9 @@ class TestTagServiceCRUD:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_update_tags_raises_error_for_duplicate_name(
|
||||
self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory
|
||||
):
|
||||
@ -826,7 +826,7 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(ValueError, match="Tag name already exists"):
|
||||
TagService.update_tags(args, tag_id="tag-123")
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that updating a non-existent tag raises NotFound.
|
||||
@ -848,8 +848,8 @@ class TestTagServiceCRUD:
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock duplicate check and current_user
|
||||
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[]):
|
||||
with patch("services.tag_service.current_user") as mock_user:
|
||||
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True):
|
||||
with patch("services.tag_service.current_user", autospec=True) as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
args = {"name": "New Name", "type": "app"}
|
||||
|
||||
@ -858,7 +858,7 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(NotFound, match="Tag not found"):
|
||||
TagService.update_tags(args, tag_id="nonexistent")
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tag_binding_count(self, mock_db_session, factory):
|
||||
"""
|
||||
Test getting the count of bindings for a tag.
|
||||
@ -894,7 +894,7 @@ class TestTagServiceCRUD:
|
||||
# Verify count matches expectation
|
||||
assert result == expected_count, "Binding count should match"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag(self, mock_db_session, factory):
|
||||
"""
|
||||
Test deleting a tag and its bindings.
|
||||
@ -950,7 +950,7 @@ class TestTagServiceCRUD:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag_raises_not_found(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that deleting a non-existent tag raises NotFound.
|
||||
@ -996,9 +996,9 @@ class TestTagServiceBindings:
|
||||
- check_target_exists: Validate target (dataset/app) existence
|
||||
"""
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory):
|
||||
"""
|
||||
Test creating tag bindings.
|
||||
@ -1047,9 +1047,9 @@ class TestTagServiceBindings:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory):
|
||||
"""
|
||||
Test that saving duplicate bindings is idempotent.
|
||||
@ -1088,8 +1088,8 @@ class TestTagServiceBindings:
|
||||
# Verify no new binding was added (idempotent)
|
||||
mock_db_session.add.assert_not_called(), "Should not create duplicate binding"
|
||||
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory):
|
||||
"""
|
||||
Test deleting a tag binding.
|
||||
@ -1136,8 +1136,8 @@ class TestTagServiceBindings:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory):
|
||||
"""
|
||||
Test that deleting a non-existent binding is a no-op.
|
||||
@ -1173,8 +1173,8 @@ class TestTagServiceBindings:
|
||||
# Verify no commit was made (nothing changed)
|
||||
mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test validating that a dataset target exists.
|
||||
@ -1214,8 +1214,8 @@ class TestTagServiceBindings:
|
||||
# Verify no exception was raised and query was executed
|
||||
mock_db_session.query.assert_called_once(), "Should query database for dataset"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test validating that an app target exists.
|
||||
@ -1255,8 +1255,8 @@ class TestTagServiceBindings:
|
||||
# Verify no exception was raised and query was executed
|
||||
mock_db_session.query.assert_called_once(), "Should query database for app"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_raises_not_found_for_missing_dataset(
|
||||
self, mock_db_session, mock_current_user, factory
|
||||
):
|
||||
@ -1287,8 +1287,8 @@ class TestTagServiceBindings:
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
TagService.check_target_exists("knowledge", "nonexistent")
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test that missing app raises NotFound.
|
||||
|
||||
@ -17,7 +17,9 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.variables.segments import (
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArraySegment,
|
||||
@ -28,8 +30,6 @@ from core.variables.segments import (
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.file.models import File
|
||||
from services.variable_truncator import (
|
||||
DummyVariableTruncator,
|
||||
MaxDepthExceededError,
|
||||
|
||||
@ -87,7 +87,7 @@ class TestWebhookServiceUnit:
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
|
||||
with patch.object(WebhookService, "_process_file_uploads", autospec=True) as mock_process_files:
|
||||
mock_process_files.return_value = {"file": "mocked_file_obj"}
|
||||
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
@ -123,8 +123,10 @@ class TestWebhookServiceUnit:
|
||||
mock_file.to_dict.return_value = {"file": "data"}
|
||||
|
||||
with (
|
||||
patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
|
||||
patch.object(WebhookService, "_create_file_from_binary") as mock_create,
|
||||
patch.object(
|
||||
WebhookService, "_detect_binary_mimetype", return_value="text/plain", autospec=True
|
||||
) as mock_detect,
|
||||
patch.object(WebhookService, "_create_file_from_binary", autospec=True) as mock_create,
|
||||
):
|
||||
mock_create.return_value = mock_file
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
@ -168,7 +170,7 @@ class TestWebhookServiceUnit:
|
||||
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
|
||||
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
|
||||
|
||||
with patch("services.trigger.webhook_service.logger") as mock_logger:
|
||||
with patch("services.trigger.webhook_service.logger", autospec=True) as mock_logger:
|
||||
result = WebhookService._detect_binary_mimetype(b"binary data")
|
||||
|
||||
assert result == "application/octet-stream"
|
||||
@ -245,15 +247,12 @@ class TestWebhookServiceUnit:
|
||||
assert response_data[0]["id"] == 1
|
||||
assert response_data[1]["id"] == 2
|
||||
|
||||
@patch("services.trigger.webhook_service.ToolFileManager")
|
||||
@patch("services.trigger.webhook_service.file_factory")
|
||||
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
|
||||
@patch("services.trigger.webhook_service.file_factory", autospec=True)
|
||||
def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager):
|
||||
"""Test successful file upload processing."""
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
@ -285,15 +284,12 @@ class TestWebhookServiceUnit:
|
||||
assert mock_tool_file_manager.call_count == 2
|
||||
assert mock_file_factory.build_from_mapping.call_count == 2
|
||||
|
||||
@patch("services.trigger.webhook_service.ToolFileManager")
|
||||
@patch("services.trigger.webhook_service.file_factory")
|
||||
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
|
||||
@patch("services.trigger.webhook_service.file_factory", autospec=True)
|
||||
def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager):
|
||||
"""Test file upload processing with errors."""
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
@ -544,8 +540,8 @@ class TestWebhookServiceUnit:
|
||||
|
||||
# Mock the WebhookService methods
|
||||
with (
|
||||
patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger,
|
||||
patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract,
|
||||
patch.object(WebhookService, "get_webhook_trigger_and_workflow", autospec=True) as mock_get_trigger,
|
||||
patch.object(WebhookService, "extract_and_validate_webhook_data", autospec=True) as mock_extract,
|
||||
):
|
||||
mock_trigger = MagicMock()
|
||||
mock_workflow = MagicMock()
|
||||
|
||||
@ -124,7 +124,7 @@ class TestWorkflowRunService:
|
||||
"""Create WorkflowRunService instance with mocked dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
return service
|
||||
@ -135,7 +135,7 @@ class TestWorkflowRunService:
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(mock_engine)
|
||||
return service
|
||||
@ -146,7 +146,7 @@ class TestWorkflowRunService:
|
||||
"""Test WorkflowRunService initialization with session_factory."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
@ -158,9 +158,11 @@ class TestWorkflowRunService:
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
|
||||
with patch(
|
||||
"services.workflow_run_service.sessionmaker", return_value=session_factory, autospec=True
|
||||
) as mock_sessionmaker:
|
||||
service = WorkflowRunService(mock_engine)
|
||||
|
||||
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||
|
||||
@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import App, AppMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
@ -1078,13 +1079,52 @@ class TestWorkflowService:
|
||||
mock_node_class = MagicMock()
|
||||
mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}}
|
||||
|
||||
mock_mapping.values.return_value = [{"latest": mock_node_class}]
|
||||
mock_mapping.items.return_value = [(NodeType.LLM, {"latest": mock_node_class})]
|
||||
|
||||
with patch("services.workflow_service.LATEST_VERSION", "latest"):
|
||||
result = workflow_service.get_default_block_configs()
|
||||
|
||||
assert len(result) > 0
|
||||
|
||||
def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service):
|
||||
injected_config = HttpRequestNodeConfig(
|
||||
max_connect_timeout=15,
|
||||
max_read_timeout=25,
|
||||
max_write_timeout=35,
|
||||
max_binary_size=4096,
|
||||
max_text_size=2048,
|
||||
ssl_verify=True,
|
||||
ssrf_default_max_retries=6,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
|
||||
patch("services.workflow_service.LATEST_VERSION", "latest"),
|
||||
patch(
|
||||
"services.workflow_service.build_http_request_config",
|
||||
return_value=injected_config,
|
||||
) as mock_build_config,
|
||||
):
|
||||
mock_http_node_class = MagicMock()
|
||||
mock_http_node_class.get_default_config.return_value = {"type": "http-request", "config": {}}
|
||||
mock_llm_node_class = MagicMock()
|
||||
mock_llm_node_class.get_default_config.return_value = {"type": "llm", "config": {}}
|
||||
mock_mapping.items.return_value = [
|
||||
(NodeType.HTTP_REQUEST, {"latest": mock_http_node_class}),
|
||||
(NodeType.LLM, {"latest": mock_llm_node_class}),
|
||||
]
|
||||
|
||||
result = workflow_service.get_default_block_configs()
|
||||
|
||||
assert result == [
|
||||
{"type": "http-request", "config": {}},
|
||||
{"type": "llm", "config": {}},
|
||||
]
|
||||
mock_build_config.assert_called_once()
|
||||
passed_http_filters = mock_http_node_class.get_default_config.call_args.kwargs["filters"]
|
||||
assert passed_http_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
|
||||
mock_llm_node_class.get_default_config.assert_called_once_with(filters=None)
|
||||
|
||||
def test_get_default_block_config_for_node_type(self, workflow_service):
|
||||
"""
|
||||
Test get_default_block_config returns config for specific node type.
|
||||
@ -1121,6 +1161,84 @@ class TestWorkflowService:
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_get_default_block_config_http_request_injects_default_config(self, workflow_service):
|
||||
injected_config = HttpRequestNodeConfig(
|
||||
max_connect_timeout=11,
|
||||
max_read_timeout=22,
|
||||
max_write_timeout=33,
|
||||
max_binary_size=4096,
|
||||
max_text_size=2048,
|
||||
ssl_verify=False,
|
||||
ssrf_default_max_retries=7,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
|
||||
patch("services.workflow_service.LATEST_VERSION", "latest"),
|
||||
patch(
|
||||
"services.workflow_service.build_http_request_config",
|
||||
return_value=injected_config,
|
||||
) as mock_build_config,
|
||||
):
|
||||
mock_node_class = MagicMock()
|
||||
expected = {"type": "http-request", "config": {}}
|
||||
mock_node_class.get_default_config.return_value = expected
|
||||
mock_mapping.__contains__.return_value = True
|
||||
mock_mapping.__getitem__.return_value = {"latest": mock_node_class}
|
||||
|
||||
result = workflow_service.get_default_block_config(NodeType.HTTP_REQUEST.value)
|
||||
|
||||
assert result == expected
|
||||
mock_build_config.assert_called_once()
|
||||
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
|
||||
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
|
||||
|
||||
def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service):
|
||||
provided_config = HttpRequestNodeConfig(
|
||||
max_connect_timeout=13,
|
||||
max_read_timeout=23,
|
||||
max_write_timeout=34,
|
||||
max_binary_size=8192,
|
||||
max_text_size=4096,
|
||||
ssl_verify=True,
|
||||
ssrf_default_max_retries=2,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
|
||||
patch("services.workflow_service.LATEST_VERSION", "latest"),
|
||||
patch("services.workflow_service.build_http_request_config") as mock_build_config,
|
||||
):
|
||||
mock_node_class = MagicMock()
|
||||
expected = {"type": "http-request", "config": {}}
|
||||
mock_node_class.get_default_config.return_value = expected
|
||||
mock_mapping.__contains__.return_value = True
|
||||
mock_mapping.__getitem__.return_value = {"latest": mock_node_class}
|
||||
|
||||
result = workflow_service.get_default_block_config(
|
||||
NodeType.HTTP_REQUEST.value,
|
||||
filters={HTTP_REQUEST_CONFIG_FILTER_KEY: provided_config},
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
mock_build_config.assert_not_called()
|
||||
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
|
||||
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is provided_config
|
||||
|
||||
def test_get_default_block_config_http_request_malformed_config_raises_value_error(self, workflow_service):
|
||||
with (
|
||||
patch(
|
||||
"services.workflow_service.NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}},
|
||||
),
|
||||
patch("services.workflow_service.LATEST_VERSION", "latest"),
|
||||
):
|
||||
with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
|
||||
workflow_service.get_default_block_config(
|
||||
NodeType.HTTP_REQUEST.value,
|
||||
filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"},
|
||||
)
|
||||
|
||||
# ==================== Workflow Conversion Tests ====================
|
||||
# These tests verify converting basic apps to workflow apps
|
||||
|
||||
|
||||
@ -6,8 +6,8 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.variables.segments import ObjectSegment, StringSegment
|
||||
from core.workflow.variables.types import SegmentType
|
||||
from models.model import UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from services.workflow_draft_variable_service import DraftVarLoader
|
||||
@ -174,7 +174,7 @@ class TestDraftVarLoaderSimple:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import FloatSegment
|
||||
from core.workflow.variables.segments import FloatSegment
|
||||
|
||||
mock_segment = FloatSegment(value=test_number)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
@ -224,7 +224,7 @@ class TestDraftVarLoaderSimple:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.workflow.variables.segments import ArrayAnySegment
|
||||
|
||||
mock_segment = ArrayAnySegment(value=test_array)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
@ -13,12 +13,11 @@ from core.app.app_config.entities import (
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import AppMode
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
@ -7,10 +7,10 @@ import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.variables.segments import StringSegment
|
||||
from core.workflow.variables.types import SegmentType
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account
|
||||
from models.enums import DraftVariableType
|
||||
@ -141,7 +141,7 @@ class TestDraftVariableSaver:
|
||||
|
||||
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
|
||||
) as _mock_try_offload:
|
||||
_mock_try_offload.return_value = None
|
||||
mock_segment = StringSegment(value="small value")
|
||||
@ -153,7 +153,7 @@ class TestDraftVariableSaver:
|
||||
|
||||
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
|
||||
) as _mock_try_offload:
|
||||
mock_segment = StringSegment(value="small value")
|
||||
mock_draft_var_file = WorkflowDraftVariableFile(
|
||||
@ -170,7 +170,7 @@ class TestDraftVariableSaver:
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id == mock_draft_var_file.id
|
||||
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True)
|
||||
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
|
||||
"""Test complete save workflow."""
|
||||
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
|
||||
@ -222,7 +222,7 @@ class TestWorkflowDraftVariableService:
|
||||
name="test_var",
|
||||
value=StringSegment(value="reset_value"),
|
||||
)
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result, autospec=True) as mock_reset_conv:
|
||||
result = service.reset_variable(workflow, variable)
|
||||
|
||||
mock_reset_conv.assert_called_once_with(workflow, variable)
|
||||
@ -330,8 +330,8 @@ class TestWorkflowDraftVariableService:
|
||||
# Mock workflow methods
|
||||
mock_node_config = {"type": "test_node"}
|
||||
with (
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config, autospec=True),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM, autospec=True),
|
||||
):
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
|
||||
@ -1,12 +1,7 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
@ -18,109 +13,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
mock_session_maker = MagicMock()
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execution(self):
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.id = str(uuid4())
|
||||
execution.tenant_id = "tenant-123"
|
||||
execution.app_id = "app-456"
|
||||
execution.workflow_id = "workflow-789"
|
||||
execution.workflow_run_id = "run-101"
|
||||
execution.node_id = "node-202"
|
||||
execution.index = 1
|
||||
execution.created_at = "2023-01-01T00:00:00Z"
|
||||
return execution
|
||||
|
||||
def test_get_node_last_execution_found(self, repository, mock_execution):
|
||||
"""Test getting the last execution for a node when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
compiled = call_args.compile()
|
||||
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_found(self, repository, mock_execution):
|
||||
"""Test getting execution by ID when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id(mock_execution.id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_not_found(self, repository):
|
||||
"""Test getting execution by ID when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id("non-existent-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_repository_implements_protocol(self, repository):
|
||||
"""Test that the repository implements the required protocol methods."""
|
||||
# Verify all protocol methods are implemented
|
||||
@ -136,135 +28,3 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
assert callable(repository.delete_executions_by_app)
|
||||
assert callable(repository.get_expired_executions_batch)
|
||||
assert callable(repository.delete_executions_by_ids)
|
||||
|
||||
def test_delete_expired_executions(self, repository):
|
||||
"""Test deleting expired executions."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.delete_expired_executions(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_app(self, repository):
|
||||
"""Test deleting executions by app."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"]
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_app(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_get_expired_executions_batch(self, repository):
|
||||
"""Test getting expired executions batch for backup."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Create mock execution objects
|
||||
mock_execution1 = MagicMock()
|
||||
mock_execution1.id = "exec-1"
|
||||
mock_execution2 = MagicMock()
|
||||
mock_execution2.id = "exec-2"
|
||||
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.get_expired_executions_batch(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].id == "exec-1"
|
||||
assert result[1].id == "exec-2"
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids(self, repository):
|
||||
"""Test deleting executions by IDs."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the delete query result
|
||||
mock_result = MagicMock()
|
||||
mock_result.rowcount = 3
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
execution_ids = ["id1", "id2", "id3"]
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids(execution_ids)
|
||||
|
||||
# Assert
|
||||
assert result == 3
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids_empty_list(self, repository):
|
||||
"""Test deleting executions with empty ID list."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids([])
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user