Merge commit '9c339239' into sandboxed-agent-rebase

Made-with: Cursor

# Conflicts:
#	api/README.md
#	api/controllers/console/app/workflow_draft_variable.py
#	api/core/agent/cot_agent_runner.py
#	api/core/agent/fc_agent_runner.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/plugin/backwards_invocation/model.py
#	api/core/prompt/advanced_prompt_transform.py
#	api/core/workflow/nodes/base/node.py
#	api/core/workflow/nodes/llm/llm_utils.py
#	api/core/workflow/nodes/llm/node.py
#	api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
#	api/core/workflow/nodes/question_classifier/question_classifier_node.py
#	api/core/workflow/runtime/graph_runtime_state.py
#	api/extensions/storage/base_storage.py
#	api/factories/variable_factory.py
#	api/pyproject.toml
#	api/services/variable_truncator.py
#	api/uv.lock
#	web/app/account/oauth/authorize/page.tsx
#	web/app/components/app/configuration/config-var/config-modal/field.tsx
#	web/app/components/base/alert.tsx
#	web/app/components/base/chat/chat/answer/human-input-content/executed-action.tsx
#	web/app/components/base/chat/chat/answer/more.tsx
#	web/app/components/base/chat/chat/answer/operation.tsx
#	web/app/components/base/chat/chat/answer/workflow-process.tsx
#	web/app/components/base/chat/chat/citation/index.tsx
#	web/app/components/base/chat/chat/citation/popup.tsx
#	web/app/components/base/chat/chat/citation/progress-tooltip.tsx
#	web/app/components/base/chat/chat/citation/tooltip.tsx
#	web/app/components/base/chat/chat/question.tsx
#	web/app/components/base/chat/embedded-chatbot/inputs-form/index.tsx
#	web/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown.tsx
#	web/app/components/base/markdown-blocks/form.tsx
#	web/app/components/base/prompt-editor/plugins/hitl-input-block/component-ui.tsx
#	web/app/components/base/tag-management/panel.tsx
#	web/app/components/base/tag-management/trigger.tsx
#	web/app/components/header/account-setting/index.tsx
#	web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx
#	web/app/signin/utils/post-login-redirect.ts
#	web/eslint-suppressions.json
#	web/package.json
#	web/pnpm-lock.yaml
This commit is contained in:
Novice
2026-03-23 09:00:45 +08:00
1009 changed files with 76072 additions and 18166 deletions

View File

@ -19,7 +19,7 @@ class TestApiKeyAuthFactory:
)
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
"""Test getting auth factory for all valid providers"""
with patch(auth_class_path) as mock_auth:
with patch(auth_class_path, autospec=True) as mock_auth:
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class == mock_auth
@ -46,7 +46,7 @@ class TestApiKeyAuthFactory:
(False, False),
],
)
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
def test_validate_credentials_delegates_to_auth_instance(
self, mock_get_factory, credentials_return_value, expected_result
):
@ -65,7 +65,7 @@ class TestApiKeyAuthFactory:
assert result is expected_result
mock_auth_instance.validate_credentials.assert_called_once()
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
"""Test that exceptions from auth instance are propagated"""
# Arrange

View File

@ -65,7 +65,7 @@ class TestFirecrawlAuth:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -96,7 +96,7 @@ class TestFirecrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@ -118,7 +118,7 @@ class TestFirecrawlAuth:
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@ -145,7 +145,7 @@ class TestFirecrawlAuth:
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message)
@ -167,7 +167,7 @@ class TestFirecrawlAuth:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation and normalized"""
mock_response = MagicMock()
@ -185,7 +185,7 @@ class TestFirecrawlAuth:
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")

View File

@ -35,7 +35,7 @@ class TestJinaAuth:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -53,7 +53,7 @@ class TestJinaAuth:
json={"url": "https://example.com"},
)
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error"""
mock_response = MagicMock()
@ -68,7 +68,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error"""
mock_response = MagicMock()
@ -83,7 +83,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error"""
mock_response = MagicMock()
@ -98,7 +98,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response"""
mock_response = MagicMock()
@ -114,7 +114,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response"""
mock_response = MagicMock()
@ -130,7 +130,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors"""
mock_post.side_effect = httpx.ConnectError("Network error")

View File

@ -64,7 +64,7 @@ class TestWatercrawlAuth:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@ -132,7 +132,7 @@ class TestWatercrawlAuth:
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message)
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock()
@ -193,7 +193,7 @@ class TestWatercrawlAuth:
# Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")

View File

@ -1,932 +0,0 @@
"""
Comprehensive unit tests for DatasetCollectionBindingService.
This module contains extensive unit tests for the DatasetCollectionBindingService class,
which handles dataset collection binding operations for vector database collections.
The DatasetCollectionBindingService provides methods for:
- Retrieving or creating dataset collection bindings by provider, model, and type
- Retrieving specific collection bindings by ID and type
- Managing collection bindings for different collection types (dataset, etc.)
Collection bindings are used to map embedding models (provider + model name) to
specific vector database collections, allowing datasets to share collections when
they use the same embedding model configuration.
This test suite ensures:
- Correct retrieval of existing bindings
- Proper creation of new bindings when they don't exist
- Accurate filtering by provider, model, and collection type
- Proper error handling for missing bindings
- Database transaction handling (add, commit)
- Collection name generation using Dataset.gen_collection_name_by_id
================================================================================
ARCHITECTURE OVERVIEW
================================================================================
The DatasetCollectionBindingService is a critical component in the Dify platform's
vector database management system. It serves as an abstraction layer between the
application logic and the underlying vector database collections.
Key Concepts:
1. Collection Binding: A mapping between an embedding model configuration
(provider + model name) and a vector database collection name. This allows
multiple datasets to share the same collection when they use identical
embedding models, improving resource efficiency.
2. Collection Type: Different types of collections can exist (e.g., "dataset",
"custom_type"). This allows for separation of collections based on their
intended use case or data structure.
3. Provider and Model: The combination of provider_name (e.g., "openai",
"cohere", "huggingface") and model_name (e.g., "text-embedding-ada-002")
uniquely identifies an embedding model configuration.
4. Collection Name Generation: When a new binding is created, a unique collection
name is generated using Dataset.gen_collection_name_by_id() with a UUID.
This ensures each binding has a unique collection identifier.
================================================================================
TESTING STRATEGY
================================================================================
This test suite follows a comprehensive testing strategy that covers:
1. Happy Path Scenarios:
- Successful retrieval of existing bindings
- Successful creation of new bindings
- Proper handling of default parameters
2. Edge Cases:
- Different collection types
- Various provider/model combinations
- Default vs explicit parameter usage
3. Error Handling:
- Missing bindings (for get_by_id_and_type)
- Database query failures
- Invalid parameter combinations
4. Database Interaction:
- Query construction and execution
- Transaction management (add, commit)
- Query chaining (where, order_by, first)
5. Mocking Strategy:
- Database session mocking
- Query builder chain mocking
- UUID generation mocking
- Collection name generation mocking
================================================================================
"""
"""
Import statements for the test module.
This section imports all necessary dependencies for testing the
DatasetCollectionBindingService, including:
- unittest.mock for creating mock objects
- pytest for test framework functionality
- uuid for UUID generation (used in collection name generation)
- Models and services from the application codebase
"""
from unittest.mock import Mock, patch
import pytest
from models.dataset import Dataset, DatasetCollectionBinding
from services.dataset_service import DatasetCollectionBindingService
# ============================================================================
# Test Data Factory
# ============================================================================
# The Test Data Factory pattern is used here to centralize the creation of
# test objects and mock instances. This approach provides several benefits:
#
# 1. Consistency: All test objects are created using the same factory methods,
# ensuring consistent structure across all tests.
#
# 2. Maintainability: If the structure of DatasetCollectionBinding or Dataset
# changes, we only need to update the factory methods rather than every
# individual test.
#
# 3. Reusability: Factory methods can be reused across multiple test classes,
# reducing code duplication.
#
# 4. Readability: Tests become more readable when they use descriptive factory
# method calls instead of complex object construction logic.
#
# ============================================================================
class DatasetCollectionBindingTestDataFactory:
    """Builders for the mock objects shared by the collection-binding tests.

    Centralising mock construction keeps individual tests short and means a
    change to the shape of DatasetCollectionBinding or Dataset only has to be
    reflected here rather than in every test.
    """

    @staticmethod
    def create_collection_binding_mock(
        binding_id: str = "binding-123",
        provider_name: str = "openai",
        model_name: str = "text-embedding-ada-002",
        collection_name: str = "collection-abc",
        collection_type: str = "dataset",
        created_at=None,
        **kwargs,
    ) -> Mock:
        """Build a Mock shaped like a DatasetCollectionBinding.

        Args:
            binding_id: Unique identifier for the binding.
            provider_name: Embedding model provider (e.g. "openai", "cohere").
            model_name: Embedding model name (e.g. "text-embedding-ada-002").
            collection_name: Vector database collection name.
            collection_type: Collection type (defaults to "dataset").
            created_at: Optional creation timestamp.
            **kwargs: Extra attributes attached verbatim to the mock; these
                win over the fixed attributes on a name clash.

        Returns:
            A Mock configured with the DatasetCollectionBinding spec.
        """
        stub = Mock(spec=DatasetCollectionBinding)
        fixed_attrs = {
            "id": binding_id,
            "provider_name": provider_name,
            "model_name": model_name,
            "collection_name": collection_name,
            "type": collection_type,
            "created_at": created_at,
        }
        # kwargs last so callers may override any fixed attribute.
        for attr, value in {**fixed_attrs, **kwargs}.items():
            setattr(stub, attr, value)
        return stub

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        **kwargs,
    ) -> Mock:
        """Build a Mock shaped like a Dataset.

        Args:
            dataset_id: Unique identifier for the dataset.
            **kwargs: Extra attributes attached verbatim to the mock.

        Returns:
            A Mock configured with the Dataset spec.
        """
        stub = Mock(spec=Dataset)
        stub.id = dataset_id
        for attr, value in kwargs.items():
            setattr(stub, attr, value)
        return stub
# ============================================================================
# Tests for get_dataset_collection_binding
# ============================================================================
class TestDatasetCollectionBindingServiceGetBinding:
    """Unit tests for DatasetCollectionBindingService.get_dataset_collection_binding.

    The method under test:
    1. Queries for an existing binding filtered by provider_name, model_name
       and collection_type, ordered by created_at ascending, taking the first.
    2. If none exists, creates a new DatasetCollectionBinding whose collection
       name is produced by Dataset.gen_collection_name_by_id(uuid).
    3. Adds and commits the new binding, then returns it (existing or new).

    Scenarios covered: retrieval of an existing binding, creation of a new
    binding, non-default collection types, the default collection type, and
    distinct provider/model combinations.
    """

    @pytest.fixture
    def mock_db_session(self):
        """Patch the service's db.session so query/add/commit can be inspected."""
        with patch("services.dataset_service.db.session") as mock_db:
            yield mock_db

    @staticmethod
    def _stub_query_chain(mock_db, first_result):
        """Wire mock_db.query().where().order_by().first() to yield *first_result*.

        Returns (query_mock, where_mock) so tests can assert on the chain.
        """
        query_mock = Mock()
        where_mock = Mock()
        order_by_mock = Mock()
        query_mock.where.return_value = where_mock
        where_mock.order_by.return_value = order_by_mock
        order_by_mock.first.return_value = first_result
        mock_db.query.return_value = query_mock
        return query_mock, where_mock

    def test_get_dataset_collection_binding_existing_binding_success(self, mock_db_session):
        """An existing binding is returned as-is: no add, no commit.

        Verifies the query filters on provider/model/type, orders by
        created_at, and that no new binding is created when one is found.
        """
        provider = "openai"
        model = "text-embedding-ada-002"
        ctype = "dataset"
        existing = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-123",
            provider_name=provider,
            model_name=model,
            collection_type=ctype,
        )
        query_mock, where_mock = self._stub_query_chain(mock_db_session, existing)

        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider, model_name=model, collection_type=ctype
        )

        assert result == existing
        assert result.id == "binding-123"
        assert result.provider_name == provider
        assert result.model_name == model
        assert result.type == ctype
        # Query built against the binding model with a single where clause
        # (provider + model + type filters) and ordered by created_at so the
        # oldest binding wins when several exist.
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()
        where_mock.order_by.assert_called_once()
        # Found an existing binding: nothing persisted.
        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_get_dataset_collection_binding_create_new_binding_success(self, mock_db_session):
        """With no existing binding, a new one is created, added and committed.

        The new binding carries the requested provider/model/type and a
        collection name generated via Dataset.gen_collection_name_by_id.
        """
        provider = "cohere"
        model = "embed-english-v3.0"
        ctype = "dataset"
        generated_name = "collection-generated-xyz"
        fake_uuid = "test-uuid-123"
        self._stub_query_chain(mock_db_session, None)  # no existing binding

        with patch("services.dataset_service.Dataset.gen_collection_name_by_id") as mock_gen_name:
            mock_gen_name.return_value = generated_name
            with patch("services.dataset_service.uuid.uuid4", return_value=fake_uuid):
                result = DatasetCollectionBindingService.get_dataset_collection_binding(
                    provider_name=provider, model_name=model, collection_type=ctype
                )

                assert result is not None
                assert result.provider_name == provider
                assert result.model_name == model
                assert result.type == ctype
                assert result.collection_name == generated_name
                # The UUID is stringified before being handed to the name generator.
                mock_gen_name.assert_called_once_with(str(fake_uuid))
                # Exactly one new binding added; inspect what was persisted.
                mock_db_session.add.assert_called_once()
                persisted = mock_db_session.add.call_args[0][0]
                assert isinstance(persisted, DatasetCollectionBinding)
                assert persisted.provider_name == provider
                assert persisted.model_name == model
                assert persisted.type == ctype
                assert persisted.collection_name == generated_name
                # The new binding must be committed to the database.
                mock_db_session.commit.assert_called_once()

    def test_get_dataset_collection_binding_different_collection_type(self, mock_db_session):
        """collection_type participates in the filter, so types can coexist."""
        provider = "openai"
        model = "text-embedding-ada-002"
        ctype = "custom_type"
        existing = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-456",
            provider_name=provider,
            model_name=model,
            collection_type=ctype,
        )
        query_mock, _ = self._stub_query_chain(mock_db_session, existing)

        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider, model_name=model, collection_type=ctype
        )

        assert result == existing
        assert result.type == ctype
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()

    def test_get_dataset_collection_binding_default_collection_type(self, mock_db_session):
        """Omitting collection_type falls back to the signature default "dataset"."""
        provider = "openai"
        model = "text-embedding-ada-002"
        existing = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-789",
            provider_name=provider,
            model_name=model,
            collection_type="dataset",  # the default
        )
        self._stub_query_chain(mock_db_session, existing)

        # No collection_type argument: the default must apply.
        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider, model_name=model
        )

        assert result == existing
        assert result.type == "dataset"
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)

    def test_get_dataset_collection_binding_different_provider_model_combination(self, mock_db_session):
        """Provider and model are both filters, so each combination binds separately."""
        provider = "huggingface"
        model = "sentence-transformers/all-MiniLM-L6-v2"
        ctype = "dataset"
        existing = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-hf-123",
            provider_name=provider,
            model_name=model,
            collection_type=ctype,
        )
        query_mock, _ = self._stub_query_chain(mock_db_session, existing)

        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider, model_name=model, collection_type=ctype
        )

        assert result == existing
        assert result.provider_name == provider
        assert result.model_name == model
        # One query against the binding model; the single where clause carries
        # the provider_name, model_name and collection_type filters together.
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()
# ============================================================================
# Tests for get_dataset_collection_binding_by_id_and_type
# ============================================================================
# This section contains tests for the get_dataset_collection_binding_by_id_and_type
# method, which retrieves a specific collection binding by its ID and type.
#
# Key differences from get_dataset_collection_binding:
# 1. This method queries by ID and type, not by provider/model/type
# 2. This method does NOT create a new binding if one doesn't exist
# 3. This method raises ValueError if the binding is not found
# 4. This method is typically used when you already know the binding ID
#
# Use cases:
# - Retrieving a binding that was previously created
# - Validating that a binding exists before using it
# - Accessing binding metadata when you have the ID
#
# ============================================================================
class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
    """Unit tests for DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type.

    The method under test:
    1. Filters DatasetCollectionBinding rows by collection_binding_id and collection_type.
    2. Orders by created_at ascending and takes the first match.
    3. Raises ValueError("Dataset collection binding not found") when nothing matches.
    4. Returns the matching binding otherwise.

    Unlike get_dataset_collection_binding, this method never creates a new
    binding - it is a pure lookup.

    Covered scenarios: successful lookups, missing bindings, non-default
    collection types, and the default collection type ("dataset").
    """

    @pytest.fixture
    def mock_db_session(self):
        """Patch services.dataset_service.db.session so no real database is touched.

        The yielded mock lets each test wire up and later inspect the
        query().where().order_by().first() call chain.
        """
        with patch("services.dataset_service.db.session") as mock_db:
            yield mock_db

    @staticmethod
    def _wire_query_chain(db_session, final_result):
        """Configure db_session.query().where().order_by().first() to return *final_result*.

        Returns the (query, where, order_by) mocks so tests can assert on how
        the chain was used.
        """
        query_mock = Mock()
        where_mock = Mock()
        order_by_mock = Mock()
        query_mock.where.return_value = where_mock
        where_mock.order_by.return_value = order_by_mock
        order_by_mock.first.return_value = final_result
        db_session.query.return_value = query_mock
        return query_mock, where_mock, order_by_mock

    def test_get_dataset_collection_binding_by_id_and_type_success(self, mock_db_session):
        """A binding matching both ID and type is returned unchanged."""
        # Arrange
        binding_id = "binding-123"
        binding_type = "dataset"
        stored_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id=binding_id,
            provider_name="openai",
            model_name="text-embedding-ada-002",
            collection_type=binding_type,
        )
        query_mock, where_mock, _ = self._wire_query_chain(mock_db_session, stored_binding)

        # Act
        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id=binding_id, collection_type=binding_type
        )

        # Assert
        assert result == stored_binding
        assert result.id == binding_id
        assert result.type == binding_type
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()
        where_mock.order_by.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, mock_db_session):
        """A missing binding raises ValueError instead of returning anything."""
        # Arrange: the query chain finds no row.
        query_mock, _, _ = self._wire_query_chain(mock_db_session, None)

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset collection binding not found"):
            DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                collection_binding_id="non-existent-binding", collection_type="dataset"
            )

        # The lookup itself must still have been attempted.
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, mock_db_session):
        """The type filter distinguishes bindings that share an ID but not a type."""
        # Arrange
        binding_id = "binding-456"
        binding_type = "custom_type"
        stored_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id=binding_id,
            provider_name="cohere",
            model_name="embed-english-v3.0",
            collection_type=binding_type,
        )
        query_mock, _, _ = self._wire_query_chain(mock_db_session, stored_binding)

        # Act
        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id=binding_id, collection_type=binding_type
        )

        # Assert
        assert result == stored_binding
        assert result.id == binding_id
        assert result.type == binding_type
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, mock_db_session):
        """Omitting collection_type falls back to the signature default "dataset"."""
        # Arrange
        binding_id = "binding-789"
        stored_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id=binding_id,
            provider_name="openai",
            model_name="text-embedding-ada-002",
            collection_type="dataset",  # matches the method's default
        )
        query_mock, _, _ = self._wire_query_chain(mock_db_session, stored_binding)

        # Act: call without collection_type so the default applies.
        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id=binding_id
        )

        # Assert
        assert result == stored_binding
        assert result.id == binding_id
        assert result.type == "dataset"
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, mock_db_session):
        """An ID match with the wrong type behaves like a missing binding."""
        # Arrange: the combined ID+type filters exclude every row.
        query_mock, _, _ = self._wire_query_chain(mock_db_session, None)

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset collection binding not found"):
            DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                collection_binding_id="binding-123", collection_type="dataset"
            )

        # Both the ID filter and the type filter live in the single where() call;
        # order_by()/first() are part of the standard chain and not asserted here.
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        query_mock.where.assert_called_once()
# ============================================================================
# Additional Test Scenarios and Edge Cases
# ============================================================================
# The following section could contain additional test scenarios if needed:
#
# Potential additional tests:
# 1. Test with multiple existing bindings (verify ordering by created_at)
# 2. Test with very long provider/model names (boundary testing)
# 3. Test with special characters in provider/model names
# 4. Test concurrent binding creation (thread safety)
# 5. Test database rollback scenarios
# 6. Test with None values for optional parameters
# 7. Test with empty strings for required parameters
# 8. Test collection name generation uniqueness
# 9. Test with different UUID formats
# 10. Test query performance with large datasets
#
# These scenarios are not currently implemented but could be added if needed
# based on real-world usage patterns or discovered edge cases.
#
# ============================================================================
# ============================================================================
# Integration Notes and Best Practices
# ============================================================================
#
# When using DatasetCollectionBindingService in production code, consider:
#
# 1. Error Handling:
# - Always handle ValueError exceptions when calling
# get_dataset_collection_binding_by_id_and_type
# - Check return values from get_dataset_collection_binding to ensure
# bindings were created successfully
#
# 2. Performance Considerations:
# - The service queries the database on every call, so consider caching
# bindings if they're accessed frequently
# - Collection bindings are typically long-lived, so caching is safe
#
# 3. Transaction Management:
# - New bindings are automatically committed to the database
# - If you need to rollback, ensure you're within a transaction context
#
# 4. Collection Type Usage:
# - Use "dataset" for standard dataset collections
# - Use custom types only when you need to separate collections by purpose
# - Be consistent with collection type naming across your application
#
# 5. Provider and Model Naming:
# - Use consistent provider names (e.g., "openai", not "OpenAI" or "OPENAI")
# - Use exact model names as provided by the model provider
# - These names are case-sensitive and must match exactly
#
# ============================================================================
# ============================================================================
# Database Schema Reference
# ============================================================================
#
# The DatasetCollectionBinding model has the following structure:
#
# - id: StringUUID (primary key, auto-generated)
# - provider_name: String(255) (required, e.g., "openai", "cohere")
# - model_name: String(255) (required, e.g., "text-embedding-ada-002")
# - type: String(40) (required, default: "dataset")
# - collection_name: String(64) (required, unique collection identifier)
# - created_at: DateTime (auto-generated timestamp)
#
# Indexes:
# - Primary key on id
# - Composite index on (provider_name, model_name) for efficient lookups
#
# Relationships:
# - One binding can be referenced by multiple datasets
# - Datasets reference bindings via collection_binding_id
#
# ============================================================================
# ============================================================================
# Mocking Strategy Documentation
# ============================================================================
#
# This test suite uses extensive mocking to isolate the unit under test.
# Here's how the mocking strategy works:
#
# 1. Database Session Mocking:
# - db.session is patched to prevent actual database access
# - Query chains are mocked to return predictable results
# - Add and commit operations are tracked for verification
#
# 2. Query Chain Mocking:
# - query() returns a mock query object
# - where() returns a mock where object
# - order_by() returns a mock order_by object
# - first() returns the final result (binding or None)
#
# 3. UUID Generation Mocking:
# - uuid.uuid4() is mocked to return predictable UUIDs
# - This ensures collection names are generated consistently in tests
#
# 4. Collection Name Generation Mocking:
# - Dataset.gen_collection_name_by_id() is mocked
# - This allows us to verify the method is called correctly
# - We can control the generated collection name for testing
#
# Benefits of this approach:
# - Tests run quickly (no database I/O)
# - Tests are deterministic (no random UUIDs)
# - Tests are isolated (no side effects)
# - Tests are maintainable (clear mock setup)
#
# ============================================================================

View File

@ -96,7 +96,6 @@ from unittest.mock import Mock, create_autospec, patch
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from models import Account, TenantAccountRole
from models.dataset import (
@ -536,421 +535,6 @@ class TestDatasetServiceUpdateDataset:
DatasetService.update_dataset(dataset_id, update_data, user)
# ============================================================================
# Tests for delete_dataset
# ============================================================================
class TestDatasetServiceDeleteDataset:
    """Unit tests for DatasetService.delete_dataset.

    delete_dataset:
    1. Loads the dataset by ID (returns False when it does not exist).
    2. Validates the caller's permission on the dataset.
    3. Emits the dataset_was_deleted signal so listeners can clean up.
    4. Deletes the row, commits, and returns True.

    Scenarios: successful deletion, missing dataset, and permission denial.
    """

    @pytest.fixture
    def mock_dataset_service_dependencies(self):
        """Patch every collaborator of delete_dataset.

        Yields a dict of mocks keyed by role:
        - "get_dataset": DatasetService.get_dataset
        - "check_permission": DatasetService.check_dataset_permission
        - "dataset_was_deleted": the deletion event signal
        - "db_session": the database session
        """
        with (
            patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
            patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
            patch("services.dataset_service.dataset_was_deleted") as mock_event,
            patch("extensions.ext_database.db.session") as mock_db,
        ):
            yield {
                "get_dataset": mock_get_dataset,
                "check_permission": mock_check_perm,
                "dataset_was_deleted": mock_event,
                "db_session": mock_db,
            }

    def test_delete_dataset_success(self, mock_dataset_service_dependencies):
        """Happy path: permission checked, event sent, row deleted and committed."""
        deps = mock_dataset_service_dependencies
        dataset_mock = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id="dataset-123")
        acting_user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
        deps["get_dataset"].return_value = dataset_mock

        outcome = DatasetService.delete_dataset("dataset-123", acting_user)

        assert outcome is True
        # Retrieval, permission check, cleanup signal, deletion, and commit
        # must each have happened exactly once.
        deps["get_dataset"].assert_called_once_with("dataset-123")
        deps["check_permission"].assert_called_once_with(dataset_mock, acting_user)
        deps["dataset_was_deleted"].send.assert_called_once_with(dataset_mock)
        deps["db_session"].delete.assert_called_once_with(dataset_mock)
        deps["db_session"].commit.assert_called_once()

    def test_delete_dataset_not_found(self, mock_dataset_service_dependencies):
        """A missing dataset yields False and triggers no side effects."""
        deps = mock_dataset_service_dependencies
        acting_user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
        deps["get_dataset"].return_value = None

        outcome = DatasetService.delete_dataset("non-existent-dataset", acting_user)

        assert outcome is False
        # Nothing downstream of the lookup may have run.
        deps["check_permission"].assert_not_called()
        deps["dataset_was_deleted"].send.assert_not_called()
        deps["db_session"].delete.assert_not_called()

    def test_delete_dataset_permission_denied_error(self, mock_dataset_service_dependencies):
        """NoPermissionError from the permission check aborts before deletion."""
        deps = mock_dataset_service_dependencies
        dataset_mock = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id="dataset-123")
        acting_user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
        deps["get_dataset"].return_value = dataset_mock
        deps["check_permission"].side_effect = NoPermissionError("No permission")

        with pytest.raises(NoPermissionError):
            DatasetService.delete_dataset("dataset-123", acting_user)

        # The error must have prevented the delete.
        deps["db_session"].delete.assert_not_called()
# ============================================================================
# Tests for dataset_use_check
# ============================================================================
class TestDatasetServiceDatasetUseCheck:
    """Unit tests for DatasetService.dataset_use_check.

    dataset_use_check reports whether any application currently references
    the dataset: it runs a query against AppDatasetJoin and returns the
    boolean result.

    Scenarios: dataset referenced by at least one app, and dataset with no
    references.
    """

    @pytest.fixture
    def mock_db_session(self):
        """Patch services.dataset_service.db.session to observe query execution."""
        with patch("services.dataset_service.db.session") as mock_db:
            yield mock_db

    @staticmethod
    def _stub_scalar_result(db_session, value):
        """Make db_session.execute(...).scalar_one() return *value*."""
        execution = Mock()
        execution.scalar_one.return_value = value
        db_session.execute.return_value = execution

    def test_dataset_use_check_in_use(self, mock_db_session):
        """Returns True when the query finds a referencing app."""
        # Arrange
        self._stub_scalar_result(mock_db_session, True)

        # Act & Assert
        assert DatasetService.dataset_use_check("dataset-123") is True
        mock_db_session.execute.assert_called_once()

    def test_dataset_use_check_not_in_use(self, mock_db_session):
        """Returns False when no app references the dataset."""
        # Arrange
        self._stub_scalar_result(mock_db_session, False)

        # Act & Assert
        assert DatasetService.dataset_use_check("dataset-123") is False
        mock_db_session.execute.assert_called_once()
# ============================================================================
# Tests for update_dataset_api_status
# ============================================================================
class TestDatasetServiceUpdateDatasetApiStatus:
    """Unit tests for DatasetService.update_dataset_api_status.

    update_dataset_api_status:
    1. Loads the dataset by ID, raising NotFound when absent.
    2. Requires a current user with an ID (ValueError otherwise).
    3. Writes enable_api plus the updated_by/updated_at audit fields.
    4. Commits the transaction.

    Scenarios: enabling, disabling, missing dataset, and missing user ID.
    """

    @pytest.fixture
    def mock_dataset_service_dependencies(self):
        """Patch the collaborators of update_dataset_api_status.

        Yields a dict with mocks for "get_dataset", "current_user",
        "db_session", and "naive_utc_now", plus "current_time" - the frozen
        datetime that naive_utc_now returns, for updated_at assertions.
        """
        with (
            patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
            patch(
                "services.dataset_service.current_user", create_autospec(Account, instance=True)
            ) as mock_current_user,
            patch("extensions.ext_database.db.session") as mock_db,
            patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
        ):
            frozen_now = datetime.datetime(2023, 1, 1, 12, 0, 0)
            mock_naive_utc_now.return_value = frozen_now
            mock_current_user.id = "user-123"
            yield {
                "get_dataset": mock_get_dataset,
                "current_user": mock_current_user,
                "db_session": mock_db,
                "naive_utc_now": mock_naive_utc_now,
                "current_time": frozen_now,
            }

    def test_update_dataset_api_status_enable_success(self, mock_dataset_service_dependencies):
        """Enabling API access sets enable_api, the audit fields, and commits."""
        deps = mock_dataset_service_dependencies
        dataset_mock = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
            dataset_id="dataset-123", enable_api=False
        )
        deps["get_dataset"].return_value = dataset_mock

        DatasetService.update_dataset_api_status("dataset-123", True)

        assert dataset_mock.enable_api is True
        assert dataset_mock.updated_by == "user-123"
        assert dataset_mock.updated_at == deps["current_time"]
        deps["get_dataset"].assert_called_once_with("dataset-123")
        deps["db_session"].commit.assert_called_once()

    def test_update_dataset_api_status_disable_success(self, mock_dataset_service_dependencies):
        """Disabling API access clears enable_api and commits."""
        deps = mock_dataset_service_dependencies
        dataset_mock = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
            dataset_id="dataset-123", enable_api=True
        )
        deps["get_dataset"].return_value = dataset_mock

        DatasetService.update_dataset_api_status("dataset-123", False)

        assert dataset_mock.enable_api is False
        assert dataset_mock.updated_by == "user-123"
        deps["db_session"].commit.assert_called_once()

    def test_update_dataset_api_status_not_found_error(self, mock_dataset_service_dependencies):
        """An unknown dataset ID raises NotFound and commits nothing."""
        deps = mock_dataset_service_dependencies
        deps["get_dataset"].return_value = None

        with pytest.raises(NotFound, match="Dataset not found"):
            DatasetService.update_dataset_api_status("non-existent-dataset", True)

        deps["db_session"].commit.assert_not_called()

    def test_update_dataset_api_status_missing_current_user_error(self, mock_dataset_service_dependencies):
        """A current user without an ID raises ValueError and commits nothing."""
        deps = mock_dataset_service_dependencies
        dataset_mock = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id="dataset-123")
        deps["get_dataset"].return_value = dataset_mock
        deps["current_user"].id = None  # simulate a missing user ID

        with pytest.raises(ValueError, match="Current user or current user id not found"):
            DatasetService.update_dataset_api_status("dataset-123", True)

        deps["db_session"].commit.assert_not_called()
# ============================================================================
# Tests for update_rag_pipeline_dataset_settings
# ============================================================================
@ -1058,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
# Mock embedding model
mock_embedding_model = Mock()
mock_embedding_model.model = "text-embedding-ada-002"
mock_embedding_model.model_name = "text-embedding-ada-002"
mock_embedding_model.provider = "openai"
mock_embedding_model.credentials = {}
mock_model_schema = Mock()
mock_model_schema.features = []
mock_text_embedding_model = Mock()
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
mock_embedding_model.model_type_instance = mock_text_embedding_model
mock_model_instance = Mock()
mock_model_instance.get_model_instance.return_value = mock_embedding_model

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,141 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
"""
from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
)
class TestJoinDefaultWorkspace:
    """Tests for EnterpriseService.join_default_workspace request and response handling."""

    # Fixed UUIDs keep request payloads and assertions deterministic.
    _ACCOUNT_ID = "11111111-1111-1111-1111-111111111111"
    _WORKSPACE_ID = "22222222-2222-2222-2222-222222222222"

    def test_join_default_workspace_success(self):
        """A well-formed response is parsed into a DefaultWorkspaceJoinResult."""
        payload = {"workspace_id": self._WORKSPACE_ID, "joined": True, "message": "ok"}
        with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
            mock_send_request.return_value = payload
            result = EnterpriseService.join_default_workspace(account_id=self._ACCOUNT_ID)

        assert isinstance(result, DefaultWorkspaceJoinResult)
        assert result.workspace_id == payload["workspace_id"]
        assert result.joined is True
        assert result.message == "ok"
        mock_send_request.assert_called_once_with(
            "POST",
            "/default-workspace/members",
            json={"account_id": self._ACCOUNT_ID},
            timeout=1.0,
            raise_for_status=True,
        )

    def test_join_default_workspace_invalid_response_format_raises(self):
        """A non-dict response body is rejected with ValueError."""
        with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
            mock_send_request.return_value = "not-a-dict"
            with pytest.raises(ValueError, match="Invalid response format"):
                EnterpriseService.join_default_workspace(account_id=self._ACCOUNT_ID)

    def test_join_default_workspace_invalid_account_id_raises(self):
        """A malformed account UUID is rejected with ValueError."""
        with pytest.raises(ValueError):
            EnterpriseService.join_default_workspace(account_id="not-a-uuid")

    def test_join_default_workspace_missing_required_fields_raises(self):
        """A payload lacking the "joined" field is rejected with ValueError."""
        payload = {"workspace_id": "", "message": "ok"}  # missing "joined"
        with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
            mock_send_request.return_value = payload
            with pytest.raises(ValueError, match="Invalid response payload"):
                EnterpriseService.join_default_workspace(account_id=self._ACCOUNT_ID)

    def test_join_default_workspace_joined_without_workspace_id_raises(self):
        """DefaultWorkspaceJoinResult forbids joined=True with an empty workspace_id."""
        with pytest.raises(ValueError, match="workspace_id must be non-empty when joined is True"):
            DefaultWorkspaceJoinResult(workspace_id="", joined=True, message="ok")
class TestTryJoinDefaultWorkspace:
    """Tests for try_join_default_workspace, the soft-failing wrapper.

    The wrapper must never raise: it is a no-op when enterprise mode is
    disabled, and it swallows every failure from the underlying service.
    """

    # Fixed UUID keeps call assertions deterministic.
    _ACCOUNT_ID = "11111111-1111-1111-1111-111111111111"

    def test_try_join_default_workspace_enterprise_disabled_noop(self):
        """With enterprise mode off, the underlying service is never called."""
        with (
            patch("services.enterprise.enterprise_service.dify_config") as mock_config,
            patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
        ):
            mock_config.ENTERPRISE_ENABLED = False
            try_join_default_workspace(self._ACCOUNT_ID)
            mock_join.assert_not_called()

    def test_try_join_default_workspace_successful_join_does_not_raise(self):
        """A successful join result propagates no exception."""
        with (
            patch("services.enterprise.enterprise_service.dify_config") as mock_config,
            patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_join.return_value = DefaultWorkspaceJoinResult(
                workspace_id="22222222-2222-2222-2222-222222222222",
                joined=True,
                message="ok",
            )
            try_join_default_workspace(self._ACCOUNT_ID)  # must not raise
            mock_join.assert_called_once_with(account_id=self._ACCOUNT_ID)

    def test_try_join_default_workspace_skipped_join_does_not_raise(self):
        """A skipped join (joined=False) propagates no exception."""
        with (
            patch("services.enterprise.enterprise_service.dify_config") as mock_config,
            patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_join.return_value = DefaultWorkspaceJoinResult(
                workspace_id="",
                joined=False,
                message="no default workspace configured",
            )
            try_join_default_workspace(self._ACCOUNT_ID)  # must not raise
            mock_join.assert_called_once_with(account_id=self._ACCOUNT_ID)

    def test_try_join_default_workspace_api_failure_soft_fails(self):
        """Exceptions from the service are swallowed by the wrapper."""
        with (
            patch("services.enterprise.enterprise_service.dify_config") as mock_config,
            patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_join.side_effect = Exception("network failure")
            try_join_default_workspace(self._ACCOUNT_ID)  # must not raise
            mock_join.assert_called_once_with(account_id=self._ACCOUNT_ID)

    def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
        """Even a UUID parse failure inside the service must not escape."""
        with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
            mock_config.ENTERPRISE_ENABLED = True
            try_join_default_workspace("not-a-uuid")  # must not raise

View File

@ -27,7 +27,7 @@ class TestTraceparentPropagation:
@pytest.fixture
def mock_httpx_client(self):
"""Mock httpx.Client for testing."""
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
with patch("services.enterprise.base.httpx.Client", autospec=True) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value.__enter__.return_value = mock_client
mock_client_class.return_value.__exit__.return_value = None
@ -44,7 +44,9 @@ class TestTraceparentPropagation:
# Arrange
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
with patch(
"services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent, autospec=True
):
# Act
EnterpriseRequest.send_request("GET", "/test")

View File

@ -135,8 +135,8 @@ class TestExternalDatasetServiceGetExternalKnowledgeApis:
"""
with (
patch("services.external_knowledge_service.db.paginate") as mock_paginate,
patch("services.external_knowledge_service.select"),
patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate,
patch("services.external_knowledge_service.select", autospec=True),
):
yield mock_paginate
@ -245,7 +245,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Patch ``db.session`` for all CRUD tests in this class.
"""
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
@ -263,7 +263,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
}
# We do not want to actually call the remote endpoint here, so we patch the validator.
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
assert isinstance(result, ExternalKnowledgeApis)
@ -386,7 +386,7 @@ class TestExternalDatasetServiceUsageAndBindings:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
@ -447,7 +447,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
@ -520,7 +520,7 @@ class TestExternalDatasetServiceProcessExternalApi:
fake_response = httpx.Response(200)
with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
mock_post.return_value = fake_response
result = ExternalDatasetService.process_external_api(settings, files=None)
@ -681,7 +681,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_create_external_dataset_success(self, mock_db_session: MagicMock):
@ -801,7 +801,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
@ -838,7 +838,9 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
with patch.object(
ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True
) as mock_process:
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=tenant_id,
dataset_id=dataset_id,
@ -908,7 +910,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
fake_response.status_code = 500
fake_response.json.return_value = {}
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True):
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id="tenant-1",
dataset_id="ds-1",

View File

@ -146,7 +146,7 @@ class TestHitTestingServiceRetrieve:
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session") as mock_db:
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
@ -174,9 +174,11 @@ class TestHitTestingServiceRetrieve:
]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1] # start, end
mock_retrieve.return_value = documents
@ -218,9 +220,11 @@ class TestHitTestingServiceRetrieve:
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_retrieve.return_value = documents
@ -268,10 +272,12 @@ class TestHitTestingServiceRetrieve:
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
@ -311,8 +317,10 @@ class TestHitTestingServiceRetrieve:
mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True)
with (
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
):
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
mock_format.return_value = []
@ -346,9 +354,11 @@ class TestHitTestingServiceRetrieve:
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_retrieve.return_value = documents
@ -380,7 +390,7 @@ class TestHitTestingServiceExternalRetrieve:
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session") as mock_db:
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_external_retrieve_success(self, mock_db_session):
@ -403,8 +413,10 @@ class TestHitTestingServiceExternalRetrieve:
]
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch(
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
) as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = external_documents
@ -467,8 +479,10 @@ class TestHitTestingServiceExternalRetrieve:
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch(
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
) as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = external_documents
@ -499,8 +513,10 @@ class TestHitTestingServiceExternalRetrieve:
metadata_filtering_conditions = {}
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch(
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
) as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = []
@ -542,7 +558,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85),
]
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
with patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format:
mock_format.return_value = mock_records
# Act
@ -566,7 +584,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
query = "test query"
documents = []
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
with patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format:
mock_format.return_value = []
# Act

View File

@ -147,7 +147,7 @@ class TestSegmentServiceCreateSegment:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -172,10 +172,12 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -219,10 +221,12 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -257,11 +261,13 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.ModelManager") as mock_model_manager_class,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -292,10 +298,12 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -317,7 +325,7 @@ class TestSegmentServiceUpdateSegment:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -338,10 +346,10 @@ class TestSegmentServiceUpdateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = segment
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_redis_get.return_value = None # Not indexing
mock_hash.return_value = "new-hash"
@ -368,10 +376,10 @@ class TestSegmentServiceUpdateSegment:
args = SegmentUpdateArgs(enabled=False)
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
patch("services.dataset_service.disable_segment_from_index_task") as mock_task,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_redis_get.return_value = None
mock_now.return_value = "2024-01-01T00:00:00"
@ -394,7 +402,7 @@ class TestSegmentServiceUpdateSegment:
dataset = SegmentTestDataFactory.create_dataset_mock()
args = SegmentUpdateArgs(content="Updated content")
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
mock_redis_get.return_value = "1" # Indexing in progress
# Act & Assert
@ -409,7 +417,7 @@ class TestSegmentServiceUpdateSegment:
dataset = SegmentTestDataFactory.create_dataset_mock()
args = SegmentUpdateArgs(content="Updated content")
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
mock_redis_get.return_value = None
# Act & Assert
@ -427,10 +435,10 @@ class TestSegmentServiceUpdateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = segment
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_redis_get.return_value = None
mock_hash.return_value = "new-hash"
@ -456,7 +464,7 @@ class TestSegmentServiceDeleteSegment:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_delete_segment_success(self, mock_db_session):
@ -471,10 +479,10 @@ class TestSegmentServiceDeleteSegment:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
patch("services.dataset_service.select") as mock_select,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.select", autospec=True) as mock_select,
):
mock_redis_get.return_value = None
mock_select.return_value.where.return_value = mock_select
@ -495,8 +503,8 @@ class TestSegmentServiceDeleteSegment:
dataset = SegmentTestDataFactory.create_dataset_mock()
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
):
mock_redis_get.return_value = None
@ -515,7 +523,7 @@ class TestSegmentServiceDeleteSegment:
document = SegmentTestDataFactory.create_document_mock()
dataset = SegmentTestDataFactory.create_dataset_mock()
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
mock_redis_get.return_value = "1" # Deletion in progress
# Act & Assert
@ -529,7 +537,7 @@ class TestSegmentServiceDeleteSegments:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -562,8 +570,8 @@ class TestSegmentServiceDeleteSegments:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
patch("services.dataset_service.select") as mock_select_func,
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.select", autospec=True) as mock_select_func,
):
mock_select_func.return_value = mock_select
@ -594,7 +602,7 @@ class TestSegmentServiceUpdateSegmentsStatus:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -623,9 +631,9 @@ class TestSegmentServiceUpdateSegmentsStatus:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.enable_segments_to_index_task") as mock_task,
patch("services.dataset_service.select") as mock_select_func,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task,
patch("services.dataset_service.select", autospec=True) as mock_select_func,
):
mock_redis_get.return_value = None
mock_select_func.return_value = mock_select
@ -657,10 +665,10 @@ class TestSegmentServiceUpdateSegmentsStatus:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.disable_segments_from_index_task") as mock_task,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.select") as mock_select_func,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
patch("services.dataset_service.select", autospec=True) as mock_select_func,
):
mock_redis_get.return_value = None
mock_now.return_value = "2024-01-01T00:00:00"
@ -693,7 +701,7 @@ class TestSegmentServiceGetSegments:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -771,7 +779,7 @@ class TestSegmentServiceGetSegmentById:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_get_segment_by_id_success(self, mock_db_session):
@ -814,7 +822,7 @@ class TestSegmentServiceGetChildChunks:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -876,7 +884,7 @@ class TestSegmentServiceGetChildChunkById:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_get_child_chunk_by_id_success(self, mock_db_session):
@ -919,7 +927,7 @@ class TestSegmentServiceCreateChildChunk:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -942,9 +950,11 @@ class TestSegmentServiceCreateChildChunk:
mock_db_session.query.return_value = mock_query
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -972,9 +982,11 @@ class TestSegmentServiceCreateChildChunk:
mock_db_session.query.return_value = mock_query
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -994,7 +1006,7 @@ class TestSegmentServiceUpdateChildChunk:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -1014,8 +1026,10 @@ class TestSegmentServiceUpdateChildChunk:
dataset = SegmentTestDataFactory.create_dataset_mock()
with (
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch(
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_now.return_value = "2024-01-01T00:00:00"
@ -1040,8 +1054,10 @@ class TestSegmentServiceUpdateChildChunk:
dataset = SegmentTestDataFactory.create_dataset_mock()
with (
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch(
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_vector_service.side_effect = Exception("Vector indexing failed")
mock_now.return_value = "2024-01-01T00:00:00"
@ -1059,7 +1075,7 @@ class TestSegmentServiceDeleteChildChunk:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_delete_child_chunk_success(self, mock_db_session):
@ -1068,7 +1084,9 @@ class TestSegmentServiceDeleteChildChunk:
chunk = SegmentTestDataFactory.create_child_chunk_mock()
dataset = SegmentTestDataFactory.create_dataset_mock()
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
with patch(
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
) as mock_vector_service:
# Act
SegmentService.delete_child_chunk(chunk, dataset)
@ -1083,7 +1101,9 @@ class TestSegmentServiceDeleteChildChunk:
chunk = SegmentTestDataFactory.create_child_chunk_mock()
dataset = SegmentTestDataFactory.create_dataset_mock()
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
with patch(
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
) as mock_vector_service:
mock_vector_service.side_effect = Exception("Vector deletion failed")
# Act & Assert

View File

@ -1064,6 +1064,67 @@ class TestRegisterService:
# ==================== Registration Tests ====================
def test_create_account_and_tenant_calls_default_workspace_join_when_enterprise_enabled(
    self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
    """Enterprise-only side effect should be invoked when ENTERPRISE_ENABLED is True."""
    # Arrange: registration allowed, email not frozen, enterprise mode on.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
    features = mock_external_service_dependencies["feature_service"].get_system_features.return_value
    features.is_allow_register = True
    mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
    account_stub = TestAccountAssociatedDataFactory.create_account_mock(
        account_id="11111111-1111-1111-1111-111111111111"
    )
    with (
        patch("services.account_service.AccountService.create_account") as create_account_patch,
        patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as create_workspace_patch,
        patch("services.enterprise.enterprise_service.try_join_default_workspace") as join_default_patch,
    ):
        create_account_patch.return_value = account_stub
        # Act
        created = AccountService.create_account_and_tenant(
            email="test@example.com",
            name="Test User",
            interface_language="en-US",
            password=None,
        )
        # Assert: workspace creation and the enterprise join hook both fired once.
        assert created == account_stub
        create_workspace_patch.assert_called_once_with(account=account_stub)
        join_default_patch.assert_called_once_with(str(account_stub.id))
def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled(
    self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
    """Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
    # Arrange: registration allowed, email not frozen, enterprise mode OFF.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
    features = mock_external_service_dependencies["feature_service"].get_system_features.return_value
    features.is_allow_register = True
    mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
    account_stub = TestAccountAssociatedDataFactory.create_account_mock(
        account_id="11111111-1111-1111-1111-111111111111"
    )
    with (
        patch("services.account_service.AccountService.create_account") as create_account_patch,
        patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as create_workspace_patch,
        patch("services.enterprise.enterprise_service.try_join_default_workspace") as join_default_patch,
    ):
        create_account_patch.return_value = account_stub
        # Act
        AccountService.create_account_and_tenant(
            email="test@example.com",
            name="Test User",
            interface_language="en-US",
            password=None,
        )
        # Assert: workspace still created once, but the enterprise hook never fired.
        create_workspace_patch.assert_called_once_with(account=account_stub)
        join_default_patch.assert_not_called()
def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies):
"""Test successful account registration."""
# Setup mocks
@ -1115,6 +1176,65 @@ class TestRegisterService:
mock_event.send.assert_called_once_with(mock_tenant)
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_register_calls_default_workspace_join_when_enterprise_enabled(
    self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
    """Enterprise-only side effect should be invoked after successful register commit."""
    # Arrange: registration allowed, email not frozen, enterprise mode on.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
    features = mock_external_service_dependencies["feature_service"].get_system_features.return_value
    features.is_allow_register = True
    mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
    account_stub = TestAccountAssociatedDataFactory.create_account_mock(
        account_id="11111111-1111-1111-1111-111111111111"
    )
    with (
        patch("services.account_service.AccountService.create_account") as create_account_patch,
        patch("services.enterprise.enterprise_service.try_join_default_workspace") as join_default_patch,
    ):
        create_account_patch.return_value = account_stub
        # Act
        registered = RegisterService.register(
            email="test@example.com",
            name="Test User",
            password="password123",
            language="en-US",
            create_workspace_required=False,
        )
        # Assert: register returns the created account and fires the join hook once.
        assert registered == account_stub
        join_default_patch.assert_called_once_with(str(account_stub.id))
def test_register_does_not_call_default_workspace_join_when_enterprise_disabled(
    self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
    """Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
    # Arrange: registration allowed, email not frozen, enterprise mode OFF.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
    features = mock_external_service_dependencies["feature_service"].get_system_features.return_value
    features.is_allow_register = True
    mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
    account_stub = TestAccountAssociatedDataFactory.create_account_mock(
        account_id="11111111-1111-1111-1111-111111111111"
    )
    with (
        patch("services.account_service.AccountService.create_account") as create_account_patch,
        patch("services.enterprise.enterprise_service.try_join_default_workspace") as join_default_patch,
    ):
        create_account_patch.return_value = account_stub
        # Act
        RegisterService.register(
            email="test@example.com",
            name="Test User",
            password="password123",
            language="en-US",
            create_workspace_required=False,
        )
        # Assert: the enterprise join hook must never fire outside enterprise mode.
        join_default_patch.assert_not_called()
def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies):
"""Test account registration with OAuth integration."""
# Setup mocks

View File

@ -63,3 +63,56 @@ def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
pause_state_config = call_kwargs.get("pause_state_config")
assert pause_state_config is not None
assert pause_state_config.state_owner_user_id == "owner-id"
def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch):
    """
    Regression test: ADVANCED_CHAT in blocking mode should return a plain dict
    (non-streaming), and must not go through the async retrieve_events path.
    Keeps behavior consistent with WORKFLOW blocking branch.
    """
    # Billing off; RateLimit swapped for a pass-through stub so limits never apply.
    monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
    mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)

    # The service's internal workflow lookup returns this fake workflow.
    fake_workflow = MagicMock()
    fake_workflow.id = "workflow-id"
    mocker.patch.object(AppGenerateService, "_get_workflow", return_value=fake_workflow)

    # Spy on the streaming retrieval path (must stay untouched) and stub
    # generate() to yield a plain dict for the non-streaming branch.
    events_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events")
    generate_stub = mocker.patch(
        "services.app_generate_service.AdvancedChatAppGenerator.generate",
        return_value={"result": "ok"},
    )

    # Minimal ADVANCED_CHAT app model and user.
    chat_app = MagicMock()
    chat_app.mode = AppMode.ADVANCED_CHAT
    chat_app.id = "app-id"
    chat_app.tenant_id = "tenant-id"
    chat_app.max_active_requests = 0
    chat_app.is_agent = False
    end_user = MagicMock()
    end_user.id = "user-id"

    # Act: blocking mode (streaming=False); query and inputs are required
    # by AdvancedChatAppGenerator.
    outcome = AppGenerateService.generate(
        app_model=chat_app,
        user=end_user,
        args={"workflow_id": "wf-1", "query": "hello", "inputs": {}},
        invoke_from=MagicMock(),
        streaming=False,
    )

    # Assert: the dict from generate() came back and retrieve_events() never ran.
    assert outcome == {"result": "ok"}
    assert generate_stub.call_args.kwargs.get("streaming") is False
    events_spy.assert_not_called()

View File

@ -44,9 +44,10 @@ class TestAppTaskService:
# Assert
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
if should_call_graph_engine:
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
mock_graph_engine_manager.assert_called_once()
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
else:
mock_graph_engine_manager.send_stop_command.assert_not_called()
mock_graph_engine_manager.assert_not_called()
@pytest.mark.parametrize(
"invoke_from",
@ -76,7 +77,8 @@ class TestAppTaskService:
# Assert
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
mock_graph_engine_manager.assert_called_once()
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
@patch("services.app_task_service.GraphEngineManager")
@patch("services.app_task_service.AppQueueManager")
@ -96,7 +98,7 @@ class TestAppTaskService:
app_mode = AppMode.ADVANCED_CHAT
# Simulate GraphEngine failure
mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error")
mock_graph_engine_manager.return_value.send_stop_command.side_effect = Exception("GraphEngine error")
# Act & Assert - should raise the exception since it's not caught
with pytest.raises(Exception, match="GraphEngine error"):

View File

@ -15,8 +15,8 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
class TestWorkflowRunArchiver:
"""Tests for the WorkflowRunArchiver class."""
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config")
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config", autospec=True)
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage", autospec=True)
def test_archiver_initialization(self, mock_get_storage, mock_config):
"""Test archiver can be initialized with various options."""
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver

View File

@ -214,7 +214,7 @@ def factory():
class TestAudioServiceASR:
"""Test speech-to-text (ASR) operations."""
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
"""Test successful ASR transcription in CHAT mode."""
# Arrange
@ -226,9 +226,7 @@ class TestAudioServiceASR:
file = factory.create_file_storage_mock()
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -242,7 +240,7 @@ class TestAudioServiceASR:
call_args = mock_model_instance.invoke_speech2text.call_args
assert call_args.kwargs["user"] == "user-123"
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
# Arrange
@ -254,9 +252,7 @@ class TestAudioServiceASR:
file = factory.create_file_storage_mock()
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -351,7 +347,7 @@ class TestAudioServiceASR:
with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
AudioService.transcript_asr(app_model=app, file=file)
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
"""Test that ASR raises error when no model instance is available."""
# Arrange
@ -363,8 +359,7 @@ class TestAudioServiceASR:
file = factory.create_file_storage_mock()
# Mock ModelManager to return None
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_manager.get_default_model_instance.return_value = None
# Act & Assert
@ -375,7 +370,7 @@ class TestAudioServiceASR:
class TestAudioServiceTTS:
"""Test text-to-speech (TTS) operations."""
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
"""Test successful TTS with text input."""
# Arrange
@ -388,9 +383,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio data"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -412,8 +405,8 @@ class TestAudioServiceTTS:
voice="en-US-Neural",
)
@patch("services.audio_service.db.session")
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.db.session", autospec=True)
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
"""Test successful TTS with message ID."""
# Arrange
@ -437,9 +430,7 @@ class TestAudioServiceTTS:
mock_query.first.return_value = message
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio from message"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -454,7 +445,7 @@ class TestAudioServiceTTS:
assert result == b"audio from message"
mock_model_instance.invoke_tts.assert_called_once()
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
"""Test TTS uses default voice when none specified."""
# Arrange
@ -467,9 +458,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio data"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -486,7 +475,7 @@ class TestAudioServiceTTS:
call_args = mock_model_instance.invoke_tts.call_args
assert call_args.kwargs["voice"] == "default-voice"
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
"""Test TTS gets first available voice when none is configured."""
# Arrange
@ -499,9 +488,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
mock_model_instance.invoke_tts.return_value = b"audio data"
@ -518,8 +505,8 @@ class TestAudioServiceTTS:
call_args = mock_model_instance.invoke_tts.call_args
assert call_args.kwargs["voice"] == "auto-voice"
@patch("services.audio_service.WorkflowService")
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.WorkflowService", autospec=True)
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_workflow_mode_with_draft(
self, mock_model_manager_class, mock_workflow_service_class, factory
):
@ -533,14 +520,11 @@ class TestAudioServiceTTS:
)
# Mock WorkflowService
mock_workflow_service = MagicMock()
mock_workflow_service_class.return_value = mock_workflow_service
mock_workflow_service = mock_workflow_service_class.return_value
mock_workflow_service.get_draft_workflow.return_value = draft_workflow
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"draft audio"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -565,7 +549,7 @@ class TestAudioServiceTTS:
with pytest.raises(ValueError, match="Text is required"):
AudioService.transcript_tts(app_model=app, text=None)
@patch("services.audio_service.db.session")
@patch("services.audio_service.db.session", autospec=True)
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
"""Test that TTS returns None for invalid message ID format."""
# Arrange
@ -580,7 +564,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.db.session")
@patch("services.audio_service.db.session", autospec=True)
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
"""Test that TTS returns None when message doesn't exist."""
# Arrange
@ -601,7 +585,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.db.session")
@patch("services.audio_service.db.session", autospec=True)
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
"""Test that TTS returns None when message answer is empty."""
# Arrange
@ -627,7 +611,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
"""Test that TTS raises error when no voices are available."""
# Arrange
@ -640,9 +624,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = [] # No voices available
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -655,7 +637,7 @@ class TestAudioServiceTTS:
class TestAudioServiceTTSVoices:
"""Test TTS voice listing operations."""
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
"""Test successful retrieval of TTS voices."""
# Arrange
@ -668,9 +650,7 @@ class TestAudioServiceTTSVoices:
]
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = expected_voices
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -682,7 +662,7 @@ class TestAudioServiceTTSVoices:
assert result == expected_voices
mock_model_instance.get_tts_voices.assert_called_once_with(language)
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
"""Test that TTS voices raises error when no model instance is available."""
# Arrange
@ -690,15 +670,14 @@ class TestAudioServiceTTSVoices:
language = "en-US"
# Mock ModelManager to return None
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_manager.get_default_model_instance.return_value = None
# Act & Assert
with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
"""Test that TTS voices propagates exceptions from model instance."""
# Arrange
@ -706,9 +685,7 @@ class TestAudioServiceTTSVoices:
language = "en-US"
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error")
mock_model_manager.get_default_model_instance.return_value = mock_model_instance

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.model_name = model
embedding_model.provider = provider
return embedding_model
@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
# Assert
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model
assert result.embedding_model == embedding_model.model_name
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)

View File

@ -1,472 +0,0 @@
"""
Unit tests for SegmentService.get_segments method.
Tests the retrieval of document segments with pagination and filtering:
- Basic pagination (page, limit)
- Status filtering
- Keyword search
- Ordering by position and id (to avoid duplicate data)
"""
from unittest.mock import Mock, create_autospec, patch
import pytest
from models.dataset import DocumentSegment
class SegmentServiceTestDataFactory:
    """Factory helpers that build mock objects for segment-service tests."""

    @staticmethod
    def create_segment_mock(
        segment_id: str = "segment-123",
        document_id: str = "doc-123",
        tenant_id: str = "tenant-123",
        dataset_id: str = "dataset-123",
        position: int = 1,
        content: str = "Test content",
        status: str = "completed",
        **kwargs,
    ) -> Mock:
        """
        Build an autospec'd DocumentSegment mock.

        Args:
            segment_id: Unique identifier for the segment
            document_id: Parent document ID
            tenant_id: Tenant ID the segment belongs to
            dataset_id: Parent dataset ID
            position: Position within the document
            content: Segment text content
            status: Indexing status
            **kwargs: Extra attributes; applied last, so they may override
                any of the defaults above.

        Returns:
            Mock: DocumentSegment mock object
        """
        segment = create_autospec(DocumentSegment, instance=True)
        attrs = {
            "id": segment_id,
            "document_id": document_id,
            "tenant_id": tenant_id,
            "dataset_id": dataset_id,
            "position": position,
            "content": content,
            "status": status,
        }
        attrs.update(kwargs)
        for name, value in attrs.items():
            setattr(segment, name, value)
        return segment
class TestSegmentServiceGetSegments:
"""
Comprehensive unit tests for SegmentService.get_segments method.
Tests cover:
- Basic pagination functionality
- Status list filtering
- Keyword search filtering
- Ordering (position + id for uniqueness)
- Empty results
- Combined filters
"""
@pytest.fixture
def mock_segment_service_dependencies(self):
    """
    Patch the dataset service's database handle and SQLAlchemy ``select``
    builder, yielding both mocks keyed by name ("db" / "select").
    """
    with (
        patch("services.dataset_service.db") as db_patch,
        patch("services.dataset_service.select") as select_patch,
    ):
        yield {"db": db_patch, "select": select_patch}
def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
    """
    Basic pagination: the query is built with document/tenant filters and
    db.paginate receives the expected page, per_page, max_per_page and
    error_out arguments; items and total are returned as-is.
    """
    deps = mock_segment_service_dependencies
    first = SegmentServiceTestDataFactory.create_segment_mock(
        segment_id="seg-1", position=1, content="First segment"
    )
    second = SegmentServiceTestDataFactory.create_segment_mock(
        segment_id="seg-2", position=2, content="Second segment"
    )
    deps["db"].paginate.return_value = Mock(items=[first, second], total=2)
    stmt = Mock()
    stmt.where.return_value = stmt
    stmt.order_by.return_value = stmt
    deps["select"].return_value = stmt

    from services.dataset_service import SegmentService

    items, total = SegmentService.get_segments(
        document_id="doc-123", tenant_id="tenant-123", page=1, limit=20
    )

    assert len(items) == 2
    assert total == 2
    assert [s.id for s in items] == ["seg-1", "seg-2"]
    deps["db"].paginate.assert_called_once()
    paginate_kwargs = deps["db"].paginate.call_args[1]
    assert paginate_kwargs["page"] == 1
    assert paginate_kwargs["per_page"] == 20
    assert paginate_kwargs["max_per_page"] == 100
    assert paginate_kwargs["error_out"] is False
def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
    """
    Status-list filtering: an additional WHERE clause is applied on top of
    the base document/tenant filters, and matching segments come back.
    """
    deps = mock_segment_service_dependencies
    completed = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
    indexing = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
    deps["db"].paginate.return_value = Mock(items=[completed, indexing], total=2)
    stmt = Mock()
    stmt.where.return_value = stmt
    stmt.order_by.return_value = stmt
    deps["select"].return_value = stmt

    from services.dataset_service import SegmentService

    items, total = SegmentService.get_segments(
        document_id="doc-123", tenant_id="tenant-123", status_list=["completed", "indexing"]
    )

    assert len(items) == 2
    assert total == 2
    # Base filters plus the status filter -> where() invoked at least twice.
    assert stmt.where.call_count >= 2
def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
    """
    An empty status list must be ignored (no status WHERE clause), avoiding
    an accidental WHERE-false condition that would return nothing.
    """
    deps = mock_segment_service_dependencies
    only_segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
    deps["db"].paginate.return_value = Mock(items=[only_segment], total=1)
    stmt = Mock()
    stmt.where.return_value = stmt
    stmt.order_by.return_value = stmt
    deps["select"].return_value = stmt

    from services.dataset_service import SegmentService

    items, total = SegmentService.get_segments(
        document_id="doc-123", tenant_id="tenant-123", status_list=[]
    )

    assert len(items) == 1
    assert total == 1
    # Only the base document/tenant filter -> exactly one where() call.
    assert stmt.where.call_count == 1
def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
    """
    Keyword search: a case-insensitive content filter is layered on top of
    the base filters (base + keyword = two where() calls).
    """
    deps = mock_segment_service_dependencies
    matching = SegmentServiceTestDataFactory.create_segment_mock(
        segment_id="seg-1", content="This contains search term"
    )
    deps["db"].paginate.return_value = Mock(items=[matching], total=1)
    stmt = Mock()
    stmt.where.return_value = stmt
    stmt.order_by.return_value = stmt
    deps["select"].return_value = stmt

    from services.dataset_service import SegmentService

    items, total = SegmentService.get_segments(
        document_id="doc-123", tenant_id="tenant-123", keyword="search term"
    )

    assert len(items) == 1
    assert total == 1
    # Base filters + keyword filter.
    assert stmt.where.call_count == 2
def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
    """
    Ordering: results are sorted by position and then id so that segments
    sharing a position cannot duplicate across pages.
    """
    deps = mock_segment_service_dependencies
    # Two segments share position 1 (ids break the tie), one is at position 2.
    tied_a = SegmentServiceTestDataFactory.create_segment_mock(
        segment_id="seg-1", position=1, content="Content 1"
    )
    tied_b = SegmentServiceTestDataFactory.create_segment_mock(
        segment_id="seg-2", position=1, content="Content 2"
    )
    trailing = SegmentServiceTestDataFactory.create_segment_mock(
        segment_id="seg-3", position=2, content="Content 3"
    )
    deps["db"].paginate.return_value = Mock(items=[tied_a, tied_b, trailing], total=3)
    stmt = Mock()
    stmt.where.return_value = stmt
    stmt.order_by.return_value = stmt
    deps["select"].return_value = stmt

    from services.dataset_service import SegmentService

    items, total = SegmentService.get_segments(document_id="doc-123", tenant_id="tenant-123")

    assert len(items) == 3
    assert total == 3
    stmt.order_by.assert_called_once()
def test_get_segments_empty_results(self, mock_segment_service_dependencies):
    """When nothing matches, an empty list and a zero total are returned."""
    deps = mock_segment_service_dependencies
    deps["db"].paginate.return_value = Mock(items=[], total=0)
    stmt = Mock()
    stmt.where.return_value = stmt
    stmt.order_by.return_value = stmt
    deps["select"].return_value = stmt

    from services.dataset_service import SegmentService

    items, total = SegmentService.get_segments(
        document_id="non-existent-doc", tenant_id="tenant-123"
    )

    assert items == []
    assert total == 0
def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
    """
    All filters together: status list and keyword each add a WHERE clause
    on top of the base filters, and pagination parameters pass through.
    """
    deps = mock_segment_service_dependencies
    matching = SegmentServiceTestDataFactory.create_segment_mock(
        segment_id="seg-1",
        status="completed",
        content="This is important information",
    )
    deps["db"].paginate.return_value = Mock(items=[matching], total=1)
    stmt = Mock()
    stmt.where.return_value = stmt
    stmt.order_by.return_value = stmt
    deps["select"].return_value = stmt

    from services.dataset_service import SegmentService

    items, total = SegmentService.get_segments(
        document_id="doc-123",
        tenant_id="tenant-123",
        status_list=["completed"],
        keyword="important",
        page=2,
        limit=10,
    )

    assert len(items) == 1
    assert total == 1
    # Three where() calls: base + status + keyword.
    assert stmt.where.call_count == 3
    paginate_kwargs = deps["db"].paginate.call_args[1]
    assert paginate_kwargs["page"] == 2
    assert paginate_kwargs["per_page"] == 10
def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
    """
    Test with None status list.
    Verifies:
    - None status list is handled correctly
    - No status filter is applied
    """
    # Arrange
    deps = mock_segment_service_dependencies
    only_segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
    page_result = Mock()
    page_result.items = [only_segment]
    page_result.total = 1
    deps["db"].paginate.return_value = page_result
    chained_query = Mock()
    chained_query.where.return_value = chained_query
    chained_query.order_by.return_value = chained_query
    deps["select"].return_value = chained_query
    # Act
    from services.dataset_service import SegmentService
    items, total = SegmentService.get_segments(
        document_id="doc-123",
        tenant_id="tenant-123",
        status_list=None,
    )
    # Assert
    assert len(items) == 1
    assert total == 1
    # Only the base filter clause is applied when status_list is None.
    assert chained_query.where.call_count == 1
def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
    """
    Test that max_per_page is correctly set to 100.
    Verifies:
    - max_per_page parameter is set to 100
    - This prevents excessive page sizes
    """
    # Arrange: request far more rows than the cap should allow.
    deps = mock_segment_service_dependencies
    oversized_limit = 200
    empty_page = Mock()
    empty_page.items = []
    empty_page.total = 0
    deps["db"].paginate.return_value = empty_page
    chained_query = Mock()
    chained_query.where.return_value = chained_query
    chained_query.order_by.return_value = chained_query
    deps["select"].return_value = chained_query
    # Act
    from services.dataset_service import SegmentService
    SegmentService.get_segments(
        document_id="doc-123",
        tenant_id="tenant-123",
        limit=oversized_limit,
    )
    # Assert: the hard cap of 100 is always handed to paginate().
    assert deps["db"].paginate.call_args[1]["max_per_page"] == 100

View File

@ -1,746 +0,0 @@
"""
Comprehensive unit tests for DatasetService retrieval/list methods.
This test suite covers:
- get_datasets - pagination, search, filtering, permissions
- get_dataset - single dataset retrieval
- get_datasets_by_ids - bulk retrieval
- get_process_rules - dataset processing rules
- get_dataset_queries - dataset query history
- get_related_apps - apps using the dataset
"""
from unittest.mock import Mock, create_autospec, patch
from uuid import uuid4
import pytest
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetPermission,
DatasetPermissionEnum,
DatasetProcessRule,
DatasetQuery,
)
from services.dataset_service import DatasetService, DocumentService
class DatasetRetrievalTestDataFactory:
    """Factory helpers that build the mock model objects used by the retrieval tests."""

    @staticmethod
    def _build(mock_obj: Mock, attrs: dict) -> Mock:
        """Assign every attribute in *attrs* onto *mock_obj* and return it."""
        for attr, val in attrs.items():
            setattr(mock_obj, attr, val)
        return mock_obj

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        name: str = "Test Dataset",
        tenant_id: str = "tenant-123",
        created_by: str = "user-123",
        permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
        **kwargs,
    ) -> Mock:
        """Create a mock dataset with specified attributes (kwargs override the defaults)."""
        return DatasetRetrievalTestDataFactory._build(
            Mock(spec=Dataset),
            {
                "id": dataset_id,
                "name": name,
                "tenant_id": tenant_id,
                "created_by": created_by,
                "permission": permission,
                **kwargs,
            },
        )

    @staticmethod
    def create_account_mock(
        account_id: str = "account-123",
        tenant_id: str = "tenant-123",
        role: TenantAccountRole = TenantAccountRole.NORMAL,
        **kwargs,
    ) -> Mock:
        """Create a mock account with the given tenant and role."""
        return DatasetRetrievalTestDataFactory._build(
            create_autospec(Account, instance=True),
            {"id": account_id, "current_tenant_id": tenant_id, "current_role": role, **kwargs},
        )

    @staticmethod
    def create_dataset_permission_mock(
        dataset_id: str = "dataset-123",
        account_id: str = "account-123",
        **kwargs,
    ) -> Mock:
        """Create a mock dataset permission linking an account to a dataset."""
        return DatasetRetrievalTestDataFactory._build(
            Mock(spec=DatasetPermission),
            {"dataset_id": dataset_id, "account_id": account_id, **kwargs},
        )

    @staticmethod
    def create_process_rule_mock(
        dataset_id: str = "dataset-123",
        mode: str = "automatic",
        rules: dict | None = None,
        **kwargs,
    ) -> Mock:
        """Create a mock dataset process rule; rules default to an empty dict."""
        return DatasetRetrievalTestDataFactory._build(
            Mock(spec=DatasetProcessRule),
            {"dataset_id": dataset_id, "mode": mode, "rules_dict": rules or {}, **kwargs},
        )

    @staticmethod
    def create_dataset_query_mock(
        dataset_id: str = "dataset-123",
        query_id: str = "query-123",
        **kwargs,
    ) -> Mock:
        """Create a mock dataset query record."""
        return DatasetRetrievalTestDataFactory._build(
            Mock(spec=DatasetQuery),
            {"id": query_id, "dataset_id": dataset_id, **kwargs},
        )

    @staticmethod
    def create_app_dataset_join_mock(
        app_id: str = "app-123",
        dataset_id: str = "dataset-123",
        **kwargs,
    ) -> Mock:
        """Create a mock app-dataset join row."""
        return DatasetRetrievalTestDataFactory._build(
            Mock(spec=AppDatasetJoin),
            {"app_id": app_id, "dataset_id": dataset_id, **kwargs},
        )
class TestDatasetServiceGetDatasets:
    """
    Comprehensive unit tests for DatasetService.get_datasets method.
    This test suite covers:
    - Pagination
    - Search functionality
    - Tag filtering
    - Permission-based filtering (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
    - Role-based filtering (OWNER, DATASET_OPERATOR, NORMAL)
    - include_all flag
    """

    @pytest.fixture
    def mock_dependencies(self):
        """Common mock setup for get_datasets tests."""
        # Patch the session, paginate helper, and TagService that get_datasets
        # uses, so no real database or tag lookup is ever touched.
        with (
            patch("services.dataset_service.db.session") as mock_db,
            patch("services.dataset_service.db.paginate") as mock_paginate,
            patch("services.dataset_service.TagService") as mock_tag_service,
        ):
            yield {
                "db_session": mock_db,
                "paginate": mock_paginate,
                "tag_service": mock_tag_service,
            }

    # ==================== Basic Retrieval Tests ====================
    def test_get_datasets_basic_pagination(self, mock_dependencies):
        """Test basic pagination without user or filters."""
        # Arrange
        tenant_id = str(uuid4())
        page = 1
        per_page = 20
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(
                dataset_id=f"dataset-{i}", name=f"Dataset {i}", tenant_id=tenant_id
            )
            for i in range(5)
        ]
        mock_paginate_result.total = 5
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id)
        # Assert
        assert len(datasets) == 5
        assert total == 5
        mock_dependencies["paginate"].assert_called_once()

    def test_get_datasets_with_search(self, mock_dependencies):
        """Test get_datasets with search keyword."""
        # Arrange
        tenant_id = str(uuid4())
        page = 1
        per_page = 20
        search = "test"
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(
                dataset_id="dataset-1", name="Test Dataset", tenant_id=tenant_id
            )
        ]
        mock_paginate_result.total = 1
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, search=search)
        # Assert
        assert len(datasets) == 1
        assert total == 1
        mock_dependencies["paginate"].assert_called_once()

    def test_get_datasets_with_tag_filtering(self, mock_dependencies):
        """Test get_datasets with tag_ids filtering."""
        # Arrange
        tenant_id = str(uuid4())
        page = 1
        per_page = 20
        tag_ids = ["tag-1", "tag-2"]
        # Mock tag service
        target_ids = ["dataset-1", "dataset-2"]
        mock_dependencies["tag_service"].get_target_ids_by_tag_ids.return_value = target_ids
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
            for dataset_id in target_ids
        ]
        mock_paginate_result.total = 2
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
        # Assert
        assert len(datasets) == 2
        assert total == 2
        # The tag lookup is scoped to the "knowledge" tag type and this tenant.
        mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_called_once_with(
            "knowledge", tenant_id, tag_ids
        )

    def test_get_datasets_with_empty_tag_ids(self, mock_dependencies):
        """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
        # Arrange
        tenant_id = str(uuid4())
        page = 1
        per_page = 20
        tag_ids = []
        # Mock pagination result - when tag_ids is empty, tag filtering is skipped
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
            for i in range(3)
        ]
        mock_paginate_result.total = 3
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
        # Assert
        # When tag_ids is empty, tag filtering is skipped, so normal query results are returned
        assert len(datasets) == 3
        assert total == 3
        # Tag service should not be called when tag_ids is empty
        mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_not_called()
        mock_dependencies["paginate"].assert_called_once()

    # ==================== Permission-Based Filtering Tests ====================
    def test_get_datasets_without_user_shows_only_all_team(self, mock_dependencies):
        """Test that without user, only ALL_TEAM datasets are shown."""
        # Arrange
        tenant_id = str(uuid4())
        page = 1
        per_page = 20
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(
                dataset_id="dataset-1",
                tenant_id=tenant_id,
                permission=DatasetPermissionEnum.ALL_TEAM,
            )
        ]
        mock_paginate_result.total = 1
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, user=None)
        # Assert
        assert len(datasets) == 1
        mock_dependencies["paginate"].assert_called_once()

    def test_get_datasets_owner_with_include_all(self, mock_dependencies):
        """Test that OWNER with include_all=True sees all datasets."""
        # Arrange
        tenant_id = str(uuid4())
        user = DatasetRetrievalTestDataFactory.create_account_mock(
            account_id="owner-123", tenant_id=tenant_id, role=TenantAccountRole.OWNER
        )
        # Mock dataset permissions query (empty - owner doesn't need explicit permissions)
        mock_query = Mock()
        mock_query.filter_by.return_value.all.return_value = []
        mock_dependencies["db_session"].query.return_value = mock_query
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
            for i in range(3)
        ]
        mock_paginate_result.total = 3
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(
            page=1, per_page=20, tenant_id=tenant_id, user=user, include_all=True
        )
        # Assert
        assert len(datasets) == 3
        assert total == 3

    def test_get_datasets_normal_user_only_me_permission(self, mock_dependencies):
        """Test that normal user sees ONLY_ME datasets they created."""
        # Arrange
        tenant_id = str(uuid4())
        user_id = "user-123"
        user = DatasetRetrievalTestDataFactory.create_account_mock(
            account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
        )
        # Mock dataset permissions query (no explicit permissions)
        mock_query = Mock()
        mock_query.filter_by.return_value.all.return_value = []
        mock_dependencies["db_session"].query.return_value = mock_query
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(
                dataset_id="dataset-1",
                tenant_id=tenant_id,
                created_by=user_id,
                permission=DatasetPermissionEnum.ONLY_ME,
            )
        ]
        mock_paginate_result.total = 1
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
        # Assert
        assert len(datasets) == 1
        assert total == 1

    def test_get_datasets_normal_user_all_team_permission(self, mock_dependencies):
        """Test that normal user sees ALL_TEAM datasets."""
        # Arrange
        tenant_id = str(uuid4())
        user = DatasetRetrievalTestDataFactory.create_account_mock(
            account_id="user-123", tenant_id=tenant_id, role=TenantAccountRole.NORMAL
        )
        # Mock dataset permissions query (no explicit permissions)
        mock_query = Mock()
        mock_query.filter_by.return_value.all.return_value = []
        mock_dependencies["db_session"].query.return_value = mock_query
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(
                dataset_id="dataset-1",
                tenant_id=tenant_id,
                permission=DatasetPermissionEnum.ALL_TEAM,
            )
        ]
        mock_paginate_result.total = 1
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
        # Assert
        assert len(datasets) == 1
        assert total == 1

    def test_get_datasets_normal_user_partial_team_with_permission(self, mock_dependencies):
        """Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
        # Arrange
        tenant_id = str(uuid4())
        user_id = "user-123"
        dataset_id = "dataset-1"
        user = DatasetRetrievalTestDataFactory.create_account_mock(
            account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
        )
        # Mock dataset permissions query - user has permission
        permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
            dataset_id=dataset_id, account_id=user_id
        )
        mock_query = Mock()
        mock_query.filter_by.return_value.all.return_value = [permission]
        mock_dependencies["db_session"].query.return_value = mock_query
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(
                dataset_id=dataset_id,
                tenant_id=tenant_id,
                permission=DatasetPermissionEnum.PARTIAL_TEAM,
            )
        ]
        mock_paginate_result.total = 1
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
        # Assert
        assert len(datasets) == 1
        assert total == 1

    def test_get_datasets_dataset_operator_with_permissions(self, mock_dependencies):
        """Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
        # Arrange
        tenant_id = str(uuid4())
        user_id = "operator-123"
        dataset_id = "dataset-1"
        user = DatasetRetrievalTestDataFactory.create_account_mock(
            account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
        )
        # Mock dataset permissions query - operator has permission
        permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
            dataset_id=dataset_id, account_id=user_id
        )
        mock_query = Mock()
        mock_query.filter_by.return_value.all.return_value = [permission]
        mock_dependencies["db_session"].query.return_value = mock_query
        # Mock pagination result
        mock_paginate_result = Mock()
        mock_paginate_result.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
        ]
        mock_paginate_result.total = 1
        mock_dependencies["paginate"].return_value = mock_paginate_result
        # Act
        datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
        # Assert
        assert len(datasets) == 1
        assert total == 1

    def test_get_datasets_dataset_operator_without_permissions(self, mock_dependencies):
        """Test that DATASET_OPERATOR without permissions returns empty result."""
        # Arrange
        tenant_id = str(uuid4())
        user_id = "operator-123"
        user = DatasetRetrievalTestDataFactory.create_account_mock(
            account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
        )
        # Mock dataset permissions query - no permissions
        mock_query = Mock()
        mock_query.filter_by.return_value.all.return_value = []
        mock_dependencies["db_session"].query.return_value = mock_query
        # Act - note: paginate is never reached, so no pagination mock is needed here
        datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
        # Assert
        assert datasets == []
        assert total == 0
class TestDatasetServiceGetDataset:
    """Unit tests for DatasetService.get_dataset (single-record lookup)."""

    @pytest.fixture
    def mock_dependencies(self):
        """Common mock setup for get_dataset tests."""
        with patch("services.dataset_service.db.session") as mock_db:
            yield {"db_session": mock_db}

    def test_get_dataset_success(self, mock_dependencies):
        """An existing dataset is returned and is looked up by its id."""
        target_id = str(uuid4())
        stored = DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=target_id)
        query_mock = Mock()
        query_mock.filter_by.return_value.first.return_value = stored
        mock_dependencies["db_session"].query.return_value = query_mock

        found = DatasetService.get_dataset(target_id)

        assert found is not None
        assert found.id == target_id
        query_mock.filter_by.assert_called_once_with(id=target_id)

    def test_get_dataset_not_found(self, mock_dependencies):
        """A missing dataset yields None."""
        missing_id = str(uuid4())
        query_mock = Mock()
        query_mock.filter_by.return_value.first.return_value = None
        mock_dependencies["db_session"].query.return_value = query_mock

        assert DatasetService.get_dataset(missing_id) is None
class TestDatasetServiceGetDatasetsByIds:
    """Unit tests for DatasetService.get_datasets_by_ids (bulk retrieval)."""

    @pytest.fixture
    def mock_dependencies(self):
        """Common mock setup for get_datasets_by_ids tests."""
        with patch("services.dataset_service.db.paginate") as mock_paginate:
            yield {"paginate": mock_paginate}

    def test_get_datasets_by_ids_success(self, mock_dependencies):
        """Bulk retrieval returns one dataset per requested id."""
        tenant = str(uuid4())
        requested_ids = [str(uuid4()) for _ in range(3)]
        page_mock = Mock()
        page_mock.items = [
            DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=ds_id, tenant_id=tenant)
            for ds_id in requested_ids
        ]
        page_mock.total = len(requested_ids)
        mock_dependencies["paginate"].return_value = page_mock

        datasets, total = DatasetService.get_datasets_by_ids(requested_ids, tenant)

        assert len(datasets) == 3
        assert total == 3
        assert all(ds.id in requested_ids for ds in datasets)
        mock_dependencies["paginate"].assert_called_once()

    def test_get_datasets_by_ids_empty_list(self, mock_dependencies):
        """An empty id list short-circuits without hitting the database."""
        datasets, total = DatasetService.get_datasets_by_ids([], str(uuid4()))

        assert datasets == []
        assert total == 0
        mock_dependencies["paginate"].assert_not_called()

    def test_get_datasets_by_ids_none_list(self, mock_dependencies):
        """Passing None for the id list also short-circuits to an empty result."""
        datasets, total = DatasetService.get_datasets_by_ids(None, str(uuid4()))

        assert datasets == []
        assert total == 0
        mock_dependencies["paginate"].assert_not_called()
class TestDatasetServiceGetProcessRules:
    """Unit tests for DatasetService.get_process_rules."""

    @pytest.fixture
    def mock_dependencies(self):
        """Common mock setup for get_process_rules tests."""
        with patch("services.dataset_service.db.session") as mock_db:
            yield {"db_session": mock_db}

    def test_get_process_rules_with_existing_rule(self, mock_dependencies):
        """An existing rule is returned with its mode and rules dict."""
        target_dataset = str(uuid4())
        custom_rules = {
            "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
            "segmentation": {"delimiter": "\n", "max_tokens": 500},
        }
        stored_rule = DatasetRetrievalTestDataFactory.create_process_rule_mock(
            dataset_id=target_dataset, mode="custom", rules=custom_rules
        )
        query_mock = Mock()
        # The service chains where().order_by().limit().one_or_none().
        chain_tail = query_mock.where.return_value.order_by.return_value.limit.return_value
        chain_tail.one_or_none.return_value = stored_rule
        mock_dependencies["db_session"].query.return_value = query_mock

        result = DatasetService.get_process_rules(target_dataset)

        assert result["mode"] == "custom"
        assert result["rules"] == custom_rules

    def test_get_process_rules_without_existing_rule(self, mock_dependencies):
        """Without a stored rule the service falls back to DocumentService defaults."""
        query_mock = Mock()
        chain_tail = query_mock.where.return_value.order_by.return_value.limit.return_value
        chain_tail.one_or_none.return_value = None
        mock_dependencies["db_session"].query.return_value = query_mock

        result = DatasetService.get_process_rules(str(uuid4()))

        assert result["mode"] == DocumentService.DEFAULT_RULES["mode"]
        assert "rules" in result
        assert result["rules"] == DocumentService.DEFAULT_RULES["rules"]
class TestDatasetServiceGetDatasetQueries:
    """Unit tests for DatasetService.get_dataset_queries."""

    @pytest.fixture
    def mock_dependencies(self):
        """Common mock setup for get_dataset_queries tests."""
        with patch("services.dataset_service.db.paginate") as mock_paginate:
            yield {"paginate": mock_paginate}

    def test_get_dataset_queries_success(self, mock_dependencies):
        """Stored queries for a dataset are returned with the right count."""
        target_dataset = str(uuid4())
        page, per_page = 1, 20
        page_mock = Mock()
        page_mock.items = [
            DatasetRetrievalTestDataFactory.create_dataset_query_mock(
                dataset_id=target_dataset, query_id=f"query-{n}"
            )
            for n in range(3)
        ]
        page_mock.total = 3
        mock_dependencies["paginate"].return_value = page_mock

        queries, total = DatasetService.get_dataset_queries(target_dataset, page, per_page)

        assert len(queries) == 3
        assert total == 3
        assert all(q.dataset_id == target_dataset for q in queries)
        mock_dependencies["paginate"].assert_called_once()

    def test_get_dataset_queries_empty_result(self, mock_dependencies):
        """No stored queries yields an empty page."""
        page_mock = Mock()
        page_mock.items = []
        page_mock.total = 0
        mock_dependencies["paginate"].return_value = page_mock

        queries, total = DatasetService.get_dataset_queries(str(uuid4()), 1, 20)

        assert queries == []
        assert total == 0
class TestDatasetServiceGetRelatedApps:
    """Unit tests for DatasetService.get_related_apps."""

    @pytest.fixture
    def mock_dependencies(self):
        """Common mock setup for get_related_apps tests."""
        with patch("services.dataset_service.db.session") as mock_db:
            yield {"db_session": mock_db}

    def test_get_related_apps_success(self, mock_dependencies):
        """App-dataset joins for the dataset are returned in full."""
        target_dataset = str(uuid4())
        joins = [
            DatasetRetrievalTestDataFactory.create_app_dataset_join_mock(
                app_id=f"app-{n}", dataset_id=target_dataset
            )
            for n in range(2)
        ]
        query_mock = Mock()
        query_mock.where.return_value.order_by.return_value.all.return_value = joins
        mock_dependencies["db_session"].query.return_value = query_mock

        result = DatasetService.get_related_apps(target_dataset)

        assert len(result) == 2
        assert all(j.dataset_id == target_dataset for j in result)
        # Exactly one filter and one ordering clause are applied.
        query_mock.where.assert_called_once()
        query_mock.where.return_value.order_by.assert_called_once()

    def test_get_related_apps_empty_result(self, mock_dependencies):
        """No joins means an empty list is returned."""
        query_mock = Mock()
        query_mock.where.return_value.order_by.return_value.all.return_value = []
        mock_dependencies["db_session"].query.return_value = query_mock

        assert DatasetService.get_related_apps(str(uuid4())) == []

View File

@ -1,661 +0,0 @@
import datetime
from typing import Any
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, create_autospec, patch
import pytest
from core.model_runtime.entities.model_entities import ModelType
from models.account import Account
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class DatasetUpdateTestDataFactory:
"""Factory class for creating test data and mock objects for dataset update tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
retrieval_model: str = "old_model",
embedding_model_provider: str | None = None,
embedding_model: str | None = None,
collection_binding_id: str | None = None,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.provider = provider
dataset.name = name
dataset.description = description
dataset.indexing_technique = indexing_technique
dataset.retrieval_model = retrieval_model
dataset.embedding_model_provider = embedding_model_provider
dataset.embedding_model = embedding_model
dataset.collection_binding_id = collection_binding_id
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(user_id: str = "user-789") -> Mock:
"""Create a mock user."""
user = Mock()
user.id = user_id
return user
@staticmethod
def create_external_binding_mock(
external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id"
) -> Mock:
"""Create a mock external knowledge binding."""
binding = Mock(spec=ExternalKnowledgeBindings)
binding.external_knowledge_id = external_knowledge_id
binding.external_knowledge_api_id = external_knowledge_api_id
return binding
@staticmethod
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.provider = provider
return embedding_model
@staticmethod
def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
"""Create a mock collection binding."""
binding = Mock()
binding.id = binding_id
return binding
@staticmethod
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
"""Create a mock current user."""
current_user = create_autospec(Account, instance=True)
current_user.current_tenant_id = tenant_id
return current_user
class TestDatasetServiceUpdateDataset:
"""
Comprehensive unit tests for DatasetService.update_dataset method.
This test suite covers all supported scenarios including:
- External dataset updates
- Internal dataset updates with different indexing techniques
- Embedding model updates
- Permission checks
- Error conditions and edge cases
"""
@pytest.fixture
def mock_dataset_service_dependencies(self):
    """Common mock setup for dataset service dependencies."""
    # Stub out the dataset lookup, the permission check, the db session,
    # the clock, and the duplicate-name check used by update_dataset.
    with (
        patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
        patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
        patch("extensions.ext_database.db.session") as mock_db,
        patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
        patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
    ):
        # Freeze "now" so any timestamp-based assertions are deterministic.
        current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
        mock_naive_utc_now.return_value = current_time
        yield {
            "get_dataset": mock_get_dataset,
            "check_permission": mock_check_perm,
            "db_session": mock_db,
            "naive_utc_now": mock_naive_utc_now,
            "current_time": current_time,
            "has_dataset_same_name": has_dataset_same_name,
        }
@pytest.fixture
def mock_external_provider_dependencies(self):
    """Mock setup for external provider tests."""
    with patch("services.dataset_service.Session") as mock_session:
        from extensions.ext_database import db
        # Patch the engine attribute on the db class so that constructing a
        # Session never touches a real database connection.
        # NOTE(review): presumably the service does `Session(db.engine)` — confirm.
        with patch.object(db.__class__, "engine", new_callable=Mock):
            session_mock = Mock()
            # Session is used as a context manager; hand back our mock from __enter__.
            mock_session.return_value.__enter__.return_value = session_mock
            yield session_mock
@pytest.fixture
def mock_internal_provider_dependencies(self):
    """Mock setup for internal provider tests."""
    # Stub the model manager, the collection-binding lookup, both background
    # index tasks, and the flask-login current_user proxy.
    with (
        patch("services.dataset_service.ModelManager") as mock_model_manager,
        patch(
            "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
        ) as mock_get_binding,
        patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
        patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task,
        patch(
            "services.dataset_service.current_user", create_autospec(Account, instance=True)
        ) as mock_current_user,
    ):
        # Give the patched current_user a tenant so tenant-scoped paths work.
        mock_current_user.current_tenant_id = "tenant-123"
        yield {
            "model_manager": mock_model_manager,
            "get_binding": mock_get_binding,
            "task": mock_task,
            "regenerate_task": mock_regenerate_task,
            "current_user": mock_current_user,
        }
def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]):
    """Verify the dataset row was updated exactly once with *expected_updates* and committed."""
    update_mock = mock_db.query.return_value.filter_by.return_value.update
    update_mock.assert_called_once_with(expected_updates)
    mock_db.commit.assert_called_once()
def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]):
    """Verify dataset fields — and binding fields when present — reflect *update_data*."""
    expected_name = update_data.get("name", mock_dataset.name)
    expected_description = update_data.get("description", mock_dataset.description)
    expected_retrieval = update_data.get("external_retrieval_model", mock_dataset.retrieval_model)
    assert mock_dataset.name == expected_name
    assert mock_dataset.description == expected_description
    assert mock_dataset.retrieval_model == expected_retrieval
    if "external_knowledge_id" in update_data:
        assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"]
    if "external_knowledge_api_id" in update_data:
        assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
# ==================== External Dataset Tests ====================
def test_update_external_dataset_success(
    self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
    """Test successful update of external dataset."""
    deps = mock_dataset_service_dependencies
    dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="external", name="old_name", description="old_description", retrieval_model="old_model"
    )
    deps["get_dataset"].return_value = dataset
    deps["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    binding = DatasetUpdateTestDataFactory.create_external_binding_mock()
    # The external knowledge binding lookup resolves to our mock binding.
    session = mock_external_provider_dependencies
    session.query.return_value.filter_by.return_value.first.return_value = binding

    payload = {
        "name": "new_name",
        "description": "new_description",
        "external_retrieval_model": "new_model",
        "permission": "only_me",
        "external_knowledge_id": "new_knowledge_id",
        "external_knowledge_api_id": "new_api_id",
    }

    result = DatasetService.update_dataset("dataset-123", payload, acting_user)

    deps["check_permission"].assert_called_once_with(dataset, acting_user)
    # Dataset fields and binding fields must reflect the payload.
    self._assert_external_dataset_update(dataset, binding, payload)
    # Both the dataset and the binding are persisted in one commit.
    db_mock = deps["db_session"]
    db_mock.add.assert_any_call(dataset)
    db_mock.add.assert_any_call(binding)
    db_mock.commit.assert_called_once()
    assert result == dataset
def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
    """A payload without external_knowledge_id must be rejected with ValueError."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    # Payload carries the api id but deliberately omits external_knowledge_id.
    payload = {"name": "new_name", "external_knowledge_api_id": "api_id"}
    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, acting_user)
    assert "External knowledge id is required" in str(exc_info.value)
def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
    """A payload without external_knowledge_api_id must be rejected with ValueError."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    # Payload carries the knowledge id but deliberately omits the api id.
    payload = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, acting_user)
    assert "External knowledge api id is required" in str(exc_info.value)
def test_update_external_dataset_binding_not_found_error(
    self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
    """A missing external knowledge binding row must raise ValueError."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    # The binding lookup comes back empty.
    binding_lookup = mock_external_provider_dependencies.query.return_value.filter_by.return_value
    binding_lookup.first.return_value = None
    payload = {
        "name": "new_name",
        "external_knowledge_id": "knowledge_id",
        "external_knowledge_api_id": "api_id",
    }
    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, acting_user)
    assert "External knowledge binding not found" in str(exc_info.value)
# ==================== Internal Dataset Basic Tests ====================
def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
    """Happy path: basic fields of an internal (vendor) dataset are updated in place."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor",
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
        collection_binding_id="binding-123",
    )
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    payload = {
        "name": "new_name",
        "description": "new_description",
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
    }
    updated = DatasetService.update_dataset("dataset-123", payload, acting_user)
    # Permission must have been checked exactly once for this user/dataset pair.
    mock_dataset_service_dependencies["check_permission"].assert_called_once_with(ds, acting_user)
    # The UPDATE statement carries the payload plus the audit columns.
    self._assert_database_update_called(
        mock_dataset_service_dependencies["db_session"],
        "dataset-123",
        {
            **payload,
            "updated_by": acting_user.id,
            "updated_at": mock_dataset_service_dependencies["current_time"],
        },
    )
    assert updated == ds
def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies):
"""Test that None values are filtered out except for description field."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# None-valued embedding fields must be dropped from the UPDATE payload,
# while a None description must survive (explicitly clearing it).
update_data = {
"name": "new_name",
"description": None,  # Should be included
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": None,  # Should be filtered out
"embedding_model": None,  # Should be filtered out
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with filtered data
expected_filtered_data = {
"name": "new_name",
"description": None,  # Description should be included even if None
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
# Reach into the mocked Query.update() call to grab the dict that was
# actually passed, rather than going through the shared assertion helper,
# because the timestamp key is stripped before comparison below.
actual_call_args = mock_dataset_service_dependencies[
"db_session"
].query.return_value.filter_by.return_value.update.call_args[0][0]
# Remove timestamp for comparison as it's dynamic
# NOTE(review): the fixture pins "current_time", so updated_at should be
# deterministic here — confirm why sibling tests compare it directly while
# this one excludes it.
del actual_call_args["updated_at"]
del expected_filtered_data["updated_at"]
assert actual_call_args == expected_filtered_data
# Verify return value
assert result == dataset
# ==================== Indexing Technique Switch Tests ====================
def test_update_internal_dataset_indexing_technique_to_economy(
    self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
    """Switching to 'economy' indexing must clear every embedding-model column."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    payload = {"indexing_technique": "economy", "retrieval_model": "new_model"}
    updated = DatasetService.update_dataset("dataset-123", payload, acting_user)
    # Economy indexing needs no embedding model, so those columns are nulled out.
    self._assert_database_update_called(
        mock_dataset_service_dependencies["db_session"],
        "dataset-123",
        {
            "indexing_technique": "economy",
            "embedding_model": None,
            "embedding_model_provider": None,
            "collection_binding_id": None,
            "retrieval_model": "new_model",
            "updated_by": acting_user.id,
            "updated_at": mock_dataset_service_dependencies["current_time"],
        },
    )
    assert updated == ds
def test_update_internal_dataset_indexing_technique_to_high_quality(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset indexing technique to high_quality."""
# Switching economy -> high_quality must (1) validate the embedding model,
# (2) resolve a collection binding, and (3) enqueue the vector-index task.
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock embedding model
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock()
mock_internal_provider_dependencies[
"model_manager"
].return_value.get_model_instance.return_value = embedding_model
# Mock collection binding
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock()
mock_internal_provider_dependencies["get_binding"].return_value = binding
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify embedding model was validated
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-ada-002",
)
# Verify collection binding was retrieved
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002")
# Verify database update was called with correct data
# "binding-456" is presumably the default id from create_collection_binding_mock()
# (the sibling test passes "binding-789" explicitly) — confirm in the factory.
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-ada-002",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-456",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify vector index task was triggered
# "add" = a brand-new vector index is created for the dataset.
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add")
# Verify return value
assert result == dataset
# ==================== Embedding Model Update Tests ====================
def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies):
    """If the payload names no embedding model, the stored configuration is preserved."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor",
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
        collection_binding_id="binding-123",
    )
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    payload = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
    updated = DatasetService.update_dataset("dataset-123", payload, acting_user)
    # The existing embedding columns are echoed back into the UPDATE unchanged.
    self._assert_database_update_called(
        mock_dataset_service_dependencies["db_session"],
        "dataset-123",
        {
            "name": "new_name",
            "indexing_technique": "high_quality",
            "embedding_model_provider": "openai",
            "embedding_model": "text-embedding-ada-002",
            "collection_binding_id": "binding-123",
            "retrieval_model": "new_model",
            "updated_by": acting_user.id,
            "updated_at": mock_dataset_service_dependencies["current_time"],
        },
    )
    assert updated == ds
def test_update_internal_dataset_embedding_model_update(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset with new embedding model."""
# Changing ada-002 -> 3-small on an already-high_quality dataset must
# re-validate the model, rebind the collection, re-index vectors ("update"),
# and trigger summary-index regeneration.
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock embedding model
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small")
mock_internal_provider_dependencies[
"model_manager"
].return_value.get_model_instance.return_value = embedding_model
# Mock collection binding
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789")
mock_internal_provider_dependencies["get_binding"].return_value = binding
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify embedding model was validated
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-3-small",
)
# Verify collection binding was retrieved
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small")
# Verify database update was called with correct data
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-3-small",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-789",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify vector index task was triggered
# "update" (not "add") because the dataset already had a vector index.
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update")
# Verify regenerate summary index task was triggered (when embedding_model changes)
mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with(
"dataset-123",
regenerate_reason="embedding_model_changed",
regenerate_vectors_only=True,
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies):
    """Re-submitting the current indexing technique keeps the embedding config as-is."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor",
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
        collection_binding_id="binding-123",
    )
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    payload = {
        "name": "new_name",
        "indexing_technique": "high_quality",  # identical to the stored value
        "retrieval_model": "new_model",
    }
    updated = DatasetService.update_dataset("dataset-123", payload, acting_user)
    # No technique switch happened, so the stored embedding configuration is carried over.
    self._assert_database_update_called(
        mock_dataset_service_dependencies["db_session"],
        "dataset-123",
        {
            "name": "new_name",
            "indexing_technique": "high_quality",
            "embedding_model_provider": "openai",
            "embedding_model": "text-embedding-ada-002",
            "collection_binding_id": "binding-123",
            "retrieval_model": "new_model",
            "updated_by": acting_user.id,
            "updated_at": mock_dataset_service_dependencies["current_time"],
        },
    )
    assert updated == ds
# ==================== Error Handling Tests ====================
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
    """An unknown dataset id must surface as ValueError('Dataset not found')."""
    # The lookup returns None, simulating a dataset that does not exist.
    mock_dataset_service_dependencies["get_dataset"].return_value = None
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", {"name": "new_name"}, acting_user)
    assert "Dataset not found" in str(exc_info.value)
def test_update_dataset_permission_error(self, mock_dataset_service_dependencies):
    """A permission-check failure must propagate as NoPermissionError, unswallowed."""
    ds = DatasetUpdateTestDataFactory.create_dataset_mock()
    mock_dataset_service_dependencies["get_dataset"].return_value = ds
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
    acting_user = DatasetUpdateTestDataFactory.create_user_mock()
    with pytest.raises(NoPermissionError):
        DatasetService.update_dataset("dataset-123", {"name": "new_name"}, acting_user)
def test_update_internal_dataset_embedding_model_error(
    self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
    """Test error when embedding model is not available.

    Switching to high_quality validates the embedding model via
    ModelManager.get_model_instance; any failure raised there must
    propagate to the caller unchanged.
    """
    dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
    mock_dataset_service_dependencies["get_dataset"].return_value = dataset
    mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
    user = DatasetUpdateTestDataFactory.create_user_mock()
    # Mock model manager to raise during embedding-model validation.
    mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception(
        "No Embedding Model available"
    )
    update_data = {
        "indexing_technique": "high_quality",
        "embedding_model_provider": "invalid_provider",
        "embedding_model": "invalid_model",
        "retrieval_model": "new_model",
    }
    # match= pins the message atomically; the (?i) flag keeps the original
    # case-insensitive comparison (previously done via a double .lower()).
    with pytest.raises(Exception, match="(?i)no embedding model available"):
        DatasetService.update_dataset("dataset-123", update_data, user)

View File

@ -6,66 +6,6 @@ from unittest.mock import MagicMock, patch
class TestArchivedWorkflowRunDeletion:
def test_delete_by_run_id_returns_error_when_run_missing(self):
"""session.get() returning None means the run does not exist: expect a failure result and no repo calls."""
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
repo = MagicMock()
session = MagicMock()
session.get.return_value = None
# sessionmaker() is used as a context manager inside delete_by_run_id.
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
session_maker.return_value.__exit__.return_value = None
mock_db = MagicMock()
mock_db.engine = MagicMock()
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
):
result = deleter.delete_by_run_id("run-1")
# Missing run -> structured failure, and the archive check is never reached.
assert result.success is False
assert result.error == "Workflow run run-1 not found"
repo.get_archived_run_ids.assert_not_called()
def test_delete_by_run_id_returns_error_when_not_archived(self):
"""A run that exists but is absent from the archived-id set must not be deleted."""
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
repo = MagicMock()
# Empty set: the run id will not be found among archived runs.
repo.get_archived_run_ids.return_value = set()
run = MagicMock()
run.id = "run-1"
run.tenant_id = "tenant-1"
session = MagicMock()
session.get.return_value = run
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
session_maker.return_value.__exit__.return_value = None
mock_db = MagicMock()
mock_db.engine = MagicMock()
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
patch.object(deleter, "_delete_run") as mock_delete_run,
):
result = deleter.delete_by_run_id("run-1")
# Not archived -> structured failure, and the destructive path is skipped.
assert result.success is False
assert result.error == "Workflow run run-1 is not archived"
mock_delete_run.assert_not_called()
def test_delete_by_run_id_calls_delete_run(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
@ -88,65 +28,20 @@ class TestArchivedWorkflowRunDeletion:
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker",
return_value=session_maker,
autospec=True,
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run,
patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True),
patch.object(
deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True
) as mock_delete_run,
):
result = deleter.delete_by_run_id("run-1")
assert result.success is True
mock_delete_run.assert_called_once_with(run)
def test_delete_batch_uses_repo(self):
"""delete_batch must fetch candidates via the repo (with the exact filters given) and delete each one."""
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
repo = MagicMock()
run1 = MagicMock()
run1.id = "run-1"
run1.tenant_id = "tenant-1"
run2 = MagicMock()
run2.id = "run-2"
run2.tenant_id = "tenant-1"
repo.get_archived_runs_by_time_range.return_value = [run1, run2]
session = MagicMock()
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
session_maker.return_value.__exit__.return_value = None
# Opaque sentinels: only identity matters for the pass-through assertion below.
start_date = MagicMock()
end_date = MagicMock()
mock_db = MagicMock()
mock_db.engine = MagicMock()
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
patch.object(
deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)]
) as mock_delete_run,
):
results = deleter.delete_batch(
tenant_ids=["tenant-1"],
start_date=start_date,
end_date=end_date,
limit=2,
)
# One result per candidate run.
assert len(results) == 2
# The filters are forwarded to the repo verbatim.
repo.get_archived_runs_by_time_range.assert_called_once_with(
session=session,
tenant_ids=["tenant-1"],
start_date=start_date,
end_date=end_date,
limit=2,
)
assert mock_delete_run.call_count == 2
def test_delete_run_dry_run(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
@ -155,26 +50,8 @@ class TestArchivedWorkflowRunDeletion:
run.id = "run-1"
run.tenant_id = "tenant-1"
with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo:
result = deleter._delete_run(run)
assert result.success is True
mock_get_repo.assert_not_called()
def test_delete_run_calls_repo(self):
"""_delete_run must delegate to repo.delete_runs_with_related and surface its counts."""
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
run = MagicMock()
run.id = "run-1"
run.tenant_id = "tenant-1"
repo = MagicMock()
repo.delete_runs_with_related.return_value = {"runs": 1}
with patch.object(deleter, "_get_workflow_run_repo", return_value=repo):
result = deleter._delete_run(run)
assert result.success is True
# The per-table deletion counts from the repo are passed through.
assert result.deleted_counts == {"runs": 1}
repo.delete_runs_with_related.assert_called_once()

View File

@ -1,6 +1,3 @@
import sqlalchemy as sa
from models.dataset import Document
from services.dataset_service import DocumentService
@ -9,25 +6,3 @@ def test_normalize_display_status_alias_mapping():
assert DocumentService.normalize_display_status("enabled") == "available"
assert DocumentService.normalize_display_status("archived") == "archived"
assert DocumentService.normalize_display_status("unknown") is None
def test_build_display_status_filters_available():
    """'available' must expand to exactly three non-null filter conditions."""
    conditions = DocumentService.build_display_status_filters("available")
    assert len(conditions) == 3
    assert all(cond is not None for cond in conditions)
def test_apply_display_status_filter_applies_when_status_present():
    """A recognised display status must add a WHERE clause to the query."""
    base_query = sa.select(Document)
    narrowed = DocumentService.apply_display_status_filter(base_query, "queuing")
    rendered = str(narrowed.compile(compile_kwargs={"literal_binds": True}))
    # 'queuing' maps onto indexing_status = 'waiting' in the generated SQL.
    assert "WHERE" in rendered
    assert "documents.indexing_status = 'waiting'" in rendered
def test_apply_display_status_filter_returns_same_when_invalid():
    """An unknown status must leave the query unfiltered (no WHERE clause emitted)."""
    base_query = sa.select(Document)
    untouched = DocumentService.apply_display_status_filter(base_query, "invalid")
    rendered = str(untouched.compile(compile_kwargs={"literal_binds": True}))
    assert "WHERE" not in rendered

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import App, DefaultEndUserSessionID, EndUser
from models.model import App, EndUser
from services.end_user_service import EndUserService
@ -44,113 +44,6 @@ class TestEndUserServiceFactory:
return end_user
class TestEndUserServiceGetOrCreateEndUser:
"""
Unit tests for EndUserService.get_or_create_end_user method.
This test suite covers:
- Creating new end users
- Retrieving existing end users
- Default session ID handling
- Anonymous user creation
"""
@pytest.fixture
def factory(self):
"""Provide test data factory."""
return TestEndUserServiceFactory()
# Test 01: Get or create with custom user_id
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_or_create_end_user_with_custom_user_id(self, mock_db, mock_session_class, factory):
"""Test getting or creating end user with custom user_id."""
# Arrange
app = factory.create_app_mock()
user_id = "custom-user-123"
# The service opens Session(...) as a context manager; route every
# query step back to the same mock so the chained calls resolve.
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None  # No existing user
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
# Assert
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify the created user has correct attributes
added_user = mock_session.add.call_args[0][0]
assert added_user.tenant_id == app.tenant_id
assert added_user.app_id == app.id
assert added_user.session_id == user_id
assert added_user.type == InvokeFrom.SERVICE_API
assert added_user.is_anonymous is False
# Test 02: Get or create without user_id (default session)
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_or_create_end_user_without_user_id(self, mock_db, mock_session_class, factory):
"""Test getting or creating end user without user_id uses default session."""
# Arrange
app = factory.create_app_mock()
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None  # No existing user
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=None)
# Assert
mock_session.add.assert_called_once()
added_user = mock_session.add.call_args[0][0]
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Verify _is_anonymous is set correctly (property always returns False)
# NOTE(review): this asserts the private backing field, not the public
# property — confirm this is the intended contract.
assert added_user._is_anonymous is True
# Test 03: Get existing end user
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_existing_end_user(self, mock_db, mock_session_class, factory):
"""Test retrieving an existing end user."""
# Arrange
app = factory.create_app_mock()
user_id = "existing-user-123"
existing_user = factory.create_end_user_mock(
tenant_id=app.tenant_id,
app_id=app.id,
session_id=user_id,
type=InvokeFrom.SERVICE_API,
)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = existing_user
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
# Assert
assert result == existing_user
mock_session.add.assert_not_called()  # Should not create new user
class TestEndUserServiceGetOrCreateEndUserByType:
"""
Unit tests for EndUserService.get_or_create_end_user_by_type method.
@ -167,226 +60,6 @@ class TestEndUserServiceGetOrCreateEndUserByType:
"""Provide test data factory."""
return TestEndUserServiceFactory()
# Test 04: Create new end user with SERVICE_API type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_service_api_type(self, mock_db, mock_session_class, factory):
"""Test creating new end user with SERVICE_API type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
# Empty query result forces the create path.
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# The persisted object is inspected directly; `result` itself is unused.
added_user = mock_session.add.call_args[0][0]
assert added_user.type == InvokeFrom.SERVICE_API
assert added_user.tenant_id == tenant_id
assert added_user.app_id == app_id
assert added_user.session_id == user_id
# Test 05: Create new end user with WEB_APP type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_web_app_type(self, mock_db, mock_session_class, factory):
"""Test creating new end user with WEB_APP type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
# Empty query result forces the create path.
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.WEB_APP,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
mock_session.add.assert_called_once()
added_user = mock_session.add.call_args[0][0]
assert added_user.type == InvokeFrom.WEB_APP
# Test 06: Upgrade legacy end user type
@patch("services.end_user_service.logger")
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_upgrade_legacy_end_user_type(self, mock_db, mock_session_class, mock_logger, factory):
"""Test upgrading legacy end user with different type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
# Existing user with old type
existing_user = factory.create_end_user_mock(
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
type=InvokeFrom.SERVICE_API,
)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = existing_user
# Act - Request with different type
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.WEB_APP,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result == existing_user
assert existing_user.type == InvokeFrom.WEB_APP  # Type should be updated
# The in-place type change is persisted and logged.
mock_session.commit.assert_called_once()
mock_logger.info.assert_called_once()
# Verify log message contains upgrade info
log_call = mock_logger.info.call_args[0][0]
assert "Upgrading legacy EndUser" in log_call
# Test 07: Get existing end user with matching type (no upgrade needed)
@patch("services.end_user_service.logger")
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_existing_end_user_matching_type(self, mock_db, mock_session_class, mock_logger, factory):
"""Test retrieving existing end user with matching type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
existing_user = factory.create_end_user_mock(
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
type=InvokeFrom.SERVICE_API,
)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = existing_user
# Act - Request with same type
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result == existing_user
assert existing_user.type == InvokeFrom.SERVICE_API
# No commit should be called (no type update needed)
mock_session.commit.assert_not_called()
mock_logger.info.assert_not_called()
# Test 08: Create anonymous user with default session ID
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_anonymous_user_with_default_session(self, mock_db, mock_session_class, factory):
"""Test creating anonymous user when user_id is None."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=None,
)
# Assert
mock_session.add.assert_called_once()
added_user = mock_session.add.call_args[0][0]
# user_id=None falls back to the shared default session id for both
# session_id and external_user_id.
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Verify _is_anonymous is set correctly (property always returns False)
# NOTE(review): asserts the private backing field — confirm intended contract.
assert added_user._is_anonymous is True
assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Test 09: Query ordering prioritizes matching type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_query_ordering_prioritizes_matching_type(self, mock_db, mock_session_class, factory):
"""Test that query ordering prioritizes records with matching type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
# Verify order_by was called (for type prioritization)
mock_query.order_by.assert_called_once()
# Test 10: Session context manager properly closes
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
@ -420,117 +93,3 @@ class TestEndUserServiceGetOrCreateEndUserByType:
# Verify context manager was entered and exited
mock_context.__enter__.assert_called_once()
mock_context.__exit__.assert_called_once()
# Test 11: External user ID matches session ID
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_external_user_id_matches_session_id(self, mock_db, mock_session_class, factory):
"""Test that external_user_id is set to match session_id."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "custom-external-id"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
added_user = mock_session.add.call_args[0][0]
assert added_user.external_user_id == user_id
assert added_user.session_id == user_id
# Test 12: Different InvokeFrom types
@pytest.mark.parametrize(
"invoke_type",
[
InvokeFrom.SERVICE_API,
InvokeFrom.WEB_APP,
InvokeFrom.EXPLORE,
InvokeFrom.DEBUGGER,
],
)
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_with_different_invoke_types(self, mock_db, mock_session_class, invoke_type, factory):
"""Test creating end users with different InvokeFrom types."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=invoke_type,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
added_user = mock_session.add.call_args[0][0]
assert added_user.type == invoke_type
class TestEndUserServiceGetEndUserById:
    """Unit tests for EndUserService.get_end_user_by_id."""

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_end_user_by_id_returns_end_user(self, mock_db, mock_session_class):
        """A matching row is returned, and the query filters on all three ids."""
        existing = MagicMock(spec=EndUser)
        db_session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = db_session
        lookup = MagicMock()
        db_session.query.return_value = lookup
        lookup.where.return_value = lookup
        lookup.first.return_value = existing

        found = EndUserService.get_end_user_by_id(
            tenant_id="tenant-123", app_id="app-456", end_user_id="end-user-789"
        )

        assert found == existing
        db_session.query.assert_called_once_with(EndUser)
        lookup.where.assert_called_once()
        # tenant_id, app_id and end_user_id must all appear in the WHERE clause.
        assert len(lookup.where.call_args[0]) == 3

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_end_user_by_id_returns_none(self, mock_db, mock_session_class):
        """When no row matches, the service returns None."""
        db_session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = db_session
        lookup = MagicMock()
        db_session.query.return_value = lookup
        lookup.where.return_value = lookup
        lookup.first.return_value = None

        assert EndUserService.get_end_user_by_id(tenant_id="tenant", app_id="app", end_user_id="end-user") is None

View File

@ -1,61 +0,0 @@
from __future__ import annotations
import pytest
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
from services import message_service
class _FakeMessage:
def __init__(self, message_id: str):
self.id = message_id
self.extra_contents = None
def set_extra_contents(self, contents):
self.extra_contents = contents
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
    """attach_message_extra_contents serializes repo payloads onto each message in order."""
    first_message = _FakeMessage("msg-1")
    second_message = _FakeMessage("msg-2")

    class _StubRepo:
        """Repository stub returning one HumanInputContent list, then an empty list."""

        def get_by_message_ids(self, message_ids):
            return [
                [
                    HumanInputContent(
                        workflow_run_id="workflow-run-1",
                        submitted=True,
                        form_submission_data=HumanInputFormSubmissionData(
                            node_id="node-1",
                            node_title="Approval",
                            rendered_content="Rendered",
                            action_id="approve",
                            action_text="Approve",
                        ),
                    )
                ],
                [],
            ]

    monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: _StubRepo())

    message_service.attach_message_extra_contents([first_message, second_message])

    expected_payload = [
        {
            "type": "human_input",
            "workflow_run_id": "workflow-run-1",
            "submitted": True,
            "form_submission_data": {
                "node_id": "node-1",
                "node_title": "Approval",
                "rendered_content": "Rendered",
                "action_id": "approve",
                "action_text": "Approve",
            },
        }
    ]
    assert first_message.extra_contents == expected_payload
    # A message with no contents still gets an explicit empty list assigned.
    assert second_message.extra_contents == []

View File

@ -402,7 +402,7 @@ class TestBillingDisabledPolicyFilterMessageIds:
class TestCreateMessageCleanPolicy:
"""Unit tests for create_message_clean_policy factory function."""
@patch("services.retention.conversation.messages_clean_policy.dify_config")
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
# Arrange
@ -414,8 +414,8 @@ class TestCreateMessageCleanPolicy:
# Assert
assert isinstance(policy, BillingDisabledPolicy)
@patch("services.retention.conversation.messages_clean_policy.BillingService")
@patch("services.retention.conversation.messages_clean_policy.dify_config")
@patch("services.retention.conversation.messages_clean_policy.BillingService", autospec=True)
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
"""Test that BillingSandboxPolicy is created with correct internal values."""
# Arrange
@ -554,7 +554,7 @@ class TestMessagesCleanServiceFromDays:
MessagesCleanService.from_days(policy=policy, days=-1)
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
@ -586,7 +586,7 @@ class TestMessagesCleanServiceFromDays:
dry_run = True
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
@ -613,7 +613,7 @@ class TestMessagesCleanServiceFromDays:
policy = BillingDisabledPolicy()
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta

View File

@ -134,8 +134,8 @@ def factory():
class TestRecommendedAppServiceGetApps:
"""Test get_recommended_apps_and_categories operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of recommended apps when apps are returned."""
# Arrange
@ -161,8 +161,8 @@ class TestRecommendedAppServiceGetApps:
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
"""Test fallback to builtin when no recommended apps are returned."""
# Arrange
@ -199,8 +199,8 @@ class TestRecommendedAppServiceGetApps:
# Verify fallback was called with en-US (hardcoded)
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
"""Test fallback when recommended_apps key is None."""
# Arrange
@ -232,8 +232,8 @@ class TestRecommendedAppServiceGetApps:
assert result == builtin_response
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
"""Test retrieval with different language codes."""
# Arrange
@ -262,8 +262,8 @@ class TestRecommendedAppServiceGetApps:
assert result["recommended_apps"][0]["id"] == f"app-{language}"
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
"""Test that correct factory is selected based on mode."""
# Arrange
@ -292,8 +292,8 @@ class TestRecommendedAppServiceGetApps:
class TestRecommendedAppServiceGetDetail:
"""Test get_recommend_app_detail operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of app detail."""
# Arrange
@ -324,8 +324,8 @@ class TestRecommendedAppServiceGetDetail:
assert result["name"] == "Productivity App"
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
"""Test app detail retrieval with different factory modes."""
# Arrange
@ -352,8 +352,8 @@ class TestRecommendedAppServiceGetDetail:
assert result["name"] == f"App from {mode}"
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
"""Test that None is returned when app is not found."""
# Arrange
@ -375,8 +375,8 @@ class TestRecommendedAppServiceGetDetail:
assert result is None
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
"""Test handling of empty dict response."""
# Arrange
@ -397,8 +397,8 @@ class TestRecommendedAppServiceGetDetail:
# Assert
assert result == {}
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
"""Test app detail with complex model configuration."""
# Arrange

View File

@ -3,7 +3,6 @@ Unit tests for workflow run restore functionality.
"""
from datetime import datetime
from unittest.mock import MagicMock
class TestWorkflowRunRestore:
@ -36,30 +35,3 @@ class TestWorkflowRunRestore:
assert result["created_at"].year == 2024
assert result["created_at"].month == 1
assert result["name"] == "test"
def test_restore_table_records_returns_rowcount(self):
"""Restore should return inserted rowcount."""
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
session = MagicMock()
session.execute.return_value = MagicMock(rowcount=2)
restore = WorkflowRunRestore()
records = [{"id": "p1", "workflow_run_id": "r1", "created_at": "2024-01-01T00:00:00"}]
restored = restore._restore_table_records(session, "workflow_pauses", records, schema_version="1.0")
assert restored == 2
session.execute.assert_called_once()
def test_restore_table_records_unknown_table(self):
"""Unknown table names should be ignored gracefully."""
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
session = MagicMock()
restore = WorkflowRunRestore()
restored = restore._restore_table_records(session, "unknown_table", [{"id": "x1"}], schema_version="1.0")
assert restored == 0
session.execute.assert_not_called()

View File

@ -201,8 +201,8 @@ def factory():
class TestSavedMessageServicePagination:
"""Test saved message pagination operations."""
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with an Account user."""
# Arrange
@ -247,8 +247,8 @@ class TestSavedMessageServicePagination:
include_ids=["msg-0", "msg-1", "msg-2"],
)
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with an EndUser."""
# Arrange
@ -301,8 +301,8 @@ class TestSavedMessageServicePagination:
with pytest.raises(ValueError, match="User is required"):
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with last_id parameter."""
# Arrange
@ -340,8 +340,8 @@ class TestSavedMessageServicePagination:
call_args = mock_message_pagination.call_args
assert call_args.kwargs["last_id"] == last_id
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination when user has no saved messages."""
# Arrange
@ -377,8 +377,8 @@ class TestSavedMessageServicePagination:
class TestSavedMessageServiceSave:
"""Test save message operations."""
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an Account user."""
# Arrange
@ -407,8 +407,8 @@ class TestSavedMessageServiceSave:
assert saved_message.created_by_role == "account"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an EndUser."""
# Arrange
@ -437,7 +437,7 @@ class TestSavedMessageServiceSave:
assert saved_message.created_by_role == "end_user"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_without_user_does_nothing(self, mock_db_session, factory):
"""Test that saving without user is a no-op."""
# Arrange
@ -451,8 +451,8 @@ class TestSavedMessageServiceSave:
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
"""Test that saving an already saved message is idempotent."""
# Arrange
@ -480,8 +480,8 @@ class TestSavedMessageServiceSave:
mock_db_session.commit.assert_not_called()
mock_get_message.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
"""Test that save validates message exists through MessageService."""
# Arrange
@ -508,7 +508,7 @@ class TestSavedMessageServiceSave:
class TestSavedMessageServiceDelete:
"""Test delete saved message operations."""
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_saved_message_for_account(self, mock_db_session, factory):
"""Test deleting a saved message for an Account user."""
# Arrange
@ -535,7 +535,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
"""Test deleting a saved message for an EndUser."""
# Arrange
@ -562,7 +562,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
"""Test that deleting without user is a no-op."""
# Arrange
@ -576,7 +576,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
"""Test that deleting a non-existent saved message is a no-op."""
# Arrange
@ -597,7 +597,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
"""Test that delete only removes the user's own saved message."""
# Arrange

View File

@ -315,7 +315,7 @@ class TestTagServiceRetrieval:
- get_tags_by_target_id: Get all tags bound to a specific target
"""
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tags_with_binding_counts(self, mock_db_session, factory):
"""
Test retrieving tags with their binding counts.
@ -372,7 +372,7 @@ class TestTagServiceRetrieval:
# Verify database query was called
mock_db_session.query.assert_called_once()
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tags_with_keyword_filter(self, mock_db_session, factory):
"""
Test retrieving tags filtered by keyword (case-insensitive).
@ -426,7 +426,7 @@ class TestTagServiceRetrieval:
# 2. Additional WHERE clause for keyword filtering
assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_target_ids_by_tag_ids(self, mock_db_session, factory):
"""
Test retrieving target IDs by tag IDs.
@ -482,7 +482,7 @@ class TestTagServiceRetrieval:
# Verify both queries were executed
assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory):
"""
Test that empty tag_ids returns empty list.
@ -510,7 +510,7 @@ class TestTagServiceRetrieval:
assert results == [], "Should return empty list for empty input"
mock_db_session.scalars.assert_not_called(), "Should not query database for empty input"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tag_by_tag_name(self, mock_db_session, factory):
"""
Test retrieving tags by name.
@ -552,7 +552,7 @@ class TestTagServiceRetrieval:
assert len(results) == 1, "Should find exactly one tag"
assert results[0].name == tag_name, "Tag name should match"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory):
"""
Test that missing tag_type or tag_name returns empty list.
@ -580,7 +580,7 @@ class TestTagServiceRetrieval:
# Verify no database queries were executed
mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tags_by_target_id(self, mock_db_session, factory):
"""
Test retrieving tags associated with a specific target.
@ -651,10 +651,10 @@ class TestTagServiceCRUD:
- get_tag_binding_count: Get count of bindings for a tag
"""
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.db.session")
@patch("services.tag_service.uuid.uuid4")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.uuid.uuid4", autospec=True)
def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
"""
Test creating a new tag.
@ -709,8 +709,8 @@ class TestTagServiceCRUD:
assert added_tag.created_by == "user-123", "Created by should match current user"
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory):
"""
Test that creating a tag with duplicate name raises ValueError.
@ -740,9 +740,9 @@ class TestTagServiceCRUD:
with pytest.raises(ValueError, match="Tag name already exists"):
TagService.save_tags(args)
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
"""
Test updating a tag name.
@ -792,9 +792,9 @@ class TestTagServiceCRUD:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_update_tags_raises_error_for_duplicate_name(
self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory
):
@ -826,7 +826,7 @@ class TestTagServiceCRUD:
with pytest.raises(ValueError, match="Tag name already exists"):
TagService.update_tags(args, tag_id="tag-123")
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory):
"""
Test that updating a non-existent tag raises NotFound.
@ -848,8 +848,8 @@ class TestTagServiceCRUD:
mock_query.first.return_value = None
# Mock duplicate check and current_user
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[]):
with patch("services.tag_service.current_user") as mock_user:
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True):
with patch("services.tag_service.current_user", autospec=True) as mock_user:
mock_user.current_tenant_id = "tenant-123"
args = {"name": "New Name", "type": "app"}
@ -858,7 +858,7 @@ class TestTagServiceCRUD:
with pytest.raises(NotFound, match="Tag not found"):
TagService.update_tags(args, tag_id="nonexistent")
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tag_binding_count(self, mock_db_session, factory):
"""
Test getting the count of bindings for a tag.
@ -894,7 +894,7 @@ class TestTagServiceCRUD:
# Verify count matches expectation
assert result == expected_count, "Binding count should match"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag(self, mock_db_session, factory):
"""
Test deleting a tag and its bindings.
@ -950,7 +950,7 @@ class TestTagServiceCRUD:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag_raises_not_found(self, mock_db_session, factory):
"""
Test that deleting a non-existent tag raises NotFound.
@ -996,9 +996,9 @@ class TestTagServiceBindings:
- check_target_exists: Validate target (dataset/app) existence
"""
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory):
"""
Test creating tag bindings.
@ -1047,9 +1047,9 @@ class TestTagServiceBindings:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory):
"""
Test that saving duplicate bindings is idempotent.
@ -1088,8 +1088,8 @@ class TestTagServiceBindings:
# Verify no new binding was added (idempotent)
mock_db_session.add.assert_not_called(), "Should not create duplicate binding"
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory):
"""
Test deleting a tag binding.
@ -1136,8 +1136,8 @@ class TestTagServiceBindings:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory):
"""
Test that deleting a non-existent binding is a no-op.
@ -1173,8 +1173,8 @@ class TestTagServiceBindings:
# Verify no commit was made (nothing changed)
mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete"
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory):
"""
Test validating that a dataset target exists.
@ -1214,8 +1214,8 @@ class TestTagServiceBindings:
# Verify no exception was raised and query was executed
mock_db_session.query.assert_called_once(), "Should query database for dataset"
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory):
"""
Test validating that an app target exists.
@ -1255,8 +1255,8 @@ class TestTagServiceBindings:
# Verify no exception was raised and query was executed
mock_db_session.query.assert_called_once(), "Should query database for app"
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_raises_not_found_for_missing_dataset(
self, mock_db_session, mock_current_user, factory
):
@ -1287,8 +1287,8 @@ class TestTagServiceBindings:
with pytest.raises(NotFound, match="Dataset not found"):
TagService.check_target_exists("knowledge", "nonexistent")
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory):
"""
Test that missing app raises NotFound.

View File

@ -17,7 +17,9 @@ from uuid import uuid4
import pytest
from core.variables.segments import (
from core.workflow.file.enums import FileTransferMethod, FileType
from core.workflow.file.models import File
from core.workflow.variables.segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArraySegment,
@ -28,8 +30,6 @@ from core.variables.segments import (
ObjectSegment,
StringSegment,
)
from core.workflow.file.enums import FileTransferMethod, FileType
from core.workflow.file.models import File
from services.variable_truncator import (
DummyVariableTruncator,
MaxDepthExceededError,

View File

@ -87,7 +87,7 @@ class TestWebhookServiceUnit:
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
with patch.object(WebhookService, "_process_file_uploads", autospec=True) as mock_process_files:
mock_process_files.return_value = {"file": "mocked_file_obj"}
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
@ -123,8 +123,10 @@ class TestWebhookServiceUnit:
mock_file.to_dict.return_value = {"file": "data"}
with (
patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
patch.object(WebhookService, "_create_file_from_binary") as mock_create,
patch.object(
WebhookService, "_detect_binary_mimetype", return_value="text/plain", autospec=True
) as mock_detect,
patch.object(WebhookService, "_create_file_from_binary", autospec=True) as mock_create,
):
mock_create.return_value = mock_file
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
@ -168,7 +170,7 @@ class TestWebhookServiceUnit:
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
with patch("services.trigger.webhook_service.logger") as mock_logger:
with patch("services.trigger.webhook_service.logger", autospec=True) as mock_logger:
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "application/octet-stream"
@ -245,15 +247,12 @@ class TestWebhookServiceUnit:
assert response_data[0]["id"] == 1
assert response_data[1]["id"] == 2
@patch("services.trigger.webhook_service.ToolFileManager")
@patch("services.trigger.webhook_service.file_factory")
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
@patch("services.trigger.webhook_service.file_factory", autospec=True)
def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager):
"""Test successful file upload processing."""
# Mock ToolFileManager
mock_tool_file_instance = MagicMock()
mock_tool_file_manager.return_value = mock_tool_file_instance
# Mock file creation
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
mock_tool_file = MagicMock()
mock_tool_file.id = "test_file_id"
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
@ -285,15 +284,12 @@ class TestWebhookServiceUnit:
assert mock_tool_file_manager.call_count == 2
assert mock_file_factory.build_from_mapping.call_count == 2
@patch("services.trigger.webhook_service.ToolFileManager")
@patch("services.trigger.webhook_service.file_factory")
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
@patch("services.trigger.webhook_service.file_factory", autospec=True)
def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager):
"""Test file upload processing with errors."""
# Mock ToolFileManager
mock_tool_file_instance = MagicMock()
mock_tool_file_manager.return_value = mock_tool_file_instance
# Mock file creation
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
mock_tool_file = MagicMock()
mock_tool_file.id = "test_file_id"
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
@ -544,8 +540,8 @@ class TestWebhookServiceUnit:
# Mock the WebhookService methods
with (
patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger,
patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract,
patch.object(WebhookService, "get_webhook_trigger_and_workflow", autospec=True) as mock_get_trigger,
patch.object(WebhookService, "extract_and_validate_webhook_data", autospec=True) as mock_extract,
):
mock_trigger = MagicMock()
mock_workflow = MagicMock()

View File

@ -124,7 +124,7 @@ class TestWorkflowRunService:
"""Create WorkflowRunService instance with mocked dependencies."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
return service
@ -135,7 +135,7 @@ class TestWorkflowRunService:
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(mock_engine)
return service
@ -146,7 +146,7 @@ class TestWorkflowRunService:
"""Test WorkflowRunService initialization with session_factory."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
@ -158,9 +158,11 @@ class TestWorkflowRunService:
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
with patch(
"services.workflow_run_service.sessionmaker", return_value=session_factory, autospec=True
) as mock_sessionmaker:
service = WorkflowRunService(mock_engine)
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)

View File

@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.workflow.enums import NodeType
from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
from libs.datetime_utils import naive_utc_now
from models.model import App, AppMode
from models.workflow import Workflow, WorkflowType
@ -1078,13 +1079,52 @@ class TestWorkflowService:
mock_node_class = MagicMock()
mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}}
mock_mapping.values.return_value = [{"latest": mock_node_class}]
mock_mapping.items.return_value = [(NodeType.LLM, {"latest": mock_node_class})]
with patch("services.workflow_service.LATEST_VERSION", "latest"):
result = workflow_service.get_default_block_configs()
assert len(result) > 0
def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service):
injected_config = HttpRequestNodeConfig(
max_connect_timeout=15,
max_read_timeout=25,
max_write_timeout=35,
max_binary_size=4096,
max_text_size=2048,
ssl_verify=True,
ssrf_default_max_retries=6,
)
with (
patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
patch("services.workflow_service.LATEST_VERSION", "latest"),
patch(
"services.workflow_service.build_http_request_config",
return_value=injected_config,
) as mock_build_config,
):
mock_http_node_class = MagicMock()
mock_http_node_class.get_default_config.return_value = {"type": "http-request", "config": {}}
mock_llm_node_class = MagicMock()
mock_llm_node_class.get_default_config.return_value = {"type": "llm", "config": {}}
mock_mapping.items.return_value = [
(NodeType.HTTP_REQUEST, {"latest": mock_http_node_class}),
(NodeType.LLM, {"latest": mock_llm_node_class}),
]
result = workflow_service.get_default_block_configs()
assert result == [
{"type": "http-request", "config": {}},
{"type": "llm", "config": {}},
]
mock_build_config.assert_called_once()
passed_http_filters = mock_http_node_class.get_default_config.call_args.kwargs["filters"]
assert passed_http_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
mock_llm_node_class.get_default_config.assert_called_once_with(filters=None)
def test_get_default_block_config_for_node_type(self, workflow_service):
"""
Test get_default_block_config returns config for specific node type.
@ -1121,6 +1161,84 @@ class TestWorkflowService:
assert result == {}
def test_get_default_block_config_http_request_injects_default_config(self, workflow_service):
injected_config = HttpRequestNodeConfig(
max_connect_timeout=11,
max_read_timeout=22,
max_write_timeout=33,
max_binary_size=4096,
max_text_size=2048,
ssl_verify=False,
ssrf_default_max_retries=7,
)
with (
patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
patch("services.workflow_service.LATEST_VERSION", "latest"),
patch(
"services.workflow_service.build_http_request_config",
return_value=injected_config,
) as mock_build_config,
):
mock_node_class = MagicMock()
expected = {"type": "http-request", "config": {}}
mock_node_class.get_default_config.return_value = expected
mock_mapping.__contains__.return_value = True
mock_mapping.__getitem__.return_value = {"latest": mock_node_class}
result = workflow_service.get_default_block_config(NodeType.HTTP_REQUEST.value)
assert result == expected
mock_build_config.assert_called_once()
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service):
provided_config = HttpRequestNodeConfig(
max_connect_timeout=13,
max_read_timeout=23,
max_write_timeout=34,
max_binary_size=8192,
max_text_size=4096,
ssl_verify=True,
ssrf_default_max_retries=2,
)
with (
patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
patch("services.workflow_service.LATEST_VERSION", "latest"),
patch("services.workflow_service.build_http_request_config") as mock_build_config,
):
mock_node_class = MagicMock()
expected = {"type": "http-request", "config": {}}
mock_node_class.get_default_config.return_value = expected
mock_mapping.__contains__.return_value = True
mock_mapping.__getitem__.return_value = {"latest": mock_node_class}
result = workflow_service.get_default_block_config(
NodeType.HTTP_REQUEST.value,
filters={HTTP_REQUEST_CONFIG_FILTER_KEY: provided_config},
)
assert result == expected
mock_build_config.assert_not_called()
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is provided_config
def test_get_default_block_config_http_request_malformed_config_raises_value_error(self, workflow_service):
with (
patch(
"services.workflow_service.NODE_TYPE_CLASSES_MAPPING",
{NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}},
),
patch("services.workflow_service.LATEST_VERSION", "latest"),
):
with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
workflow_service.get_default_block_config(
NodeType.HTTP_REQUEST.value,
filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"},
)
# ==================== Workflow Conversion Tests ====================
# These tests verify converting basic apps to workflow apps

View File

@ -6,8 +6,8 @@ from unittest.mock import Mock, patch
import pytest
from sqlalchemy import Engine
from core.variables.segments import ObjectSegment, StringSegment
from core.variables.types import SegmentType
from core.workflow.variables.segments import ObjectSegment, StringSegment
from core.workflow.variables.types import SegmentType
from models.model import UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from services.workflow_draft_variable_service import DraftVarLoader
@ -174,7 +174,7 @@ class TestDraftVarLoaderSimple:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import FloatSegment
from core.workflow.variables.segments import FloatSegment
mock_segment = FloatSegment(value=test_number)
mock_build_segment.return_value = mock_segment
@ -224,7 +224,7 @@ class TestDraftVarLoaderSimple:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import ArrayAnySegment
from core.workflow.variables.segments import ArrayAnySegment
mock_segment = ArrayAnySegment(value=test_array)
mock_build_segment.return_value = mock_segment

View File

@ -13,12 +13,11 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import AppMode
from services.workflow.workflow_converter import WorkflowConverter

View File

@ -7,10 +7,10 @@ import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeType
from core.workflow.variables.segments import StringSegment
from core.workflow.variables.types import SegmentType
from libs.uuid_utils import uuidv7
from models.account import Account
from models.enums import DraftVariableType
@ -141,7 +141,7 @@ class TestDraftVariableSaver:
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
) as _mock_try_offload:
_mock_try_offload.return_value = None
mock_segment = StringSegment(value="small value")
@ -153,7 +153,7 @@ class TestDraftVariableSaver:
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
) as _mock_try_offload:
mock_segment = StringSegment(value="small value")
mock_draft_var_file = WorkflowDraftVariableFile(
@ -170,7 +170,7 @@ class TestDraftVariableSaver:
# Should not have large variable metadata
assert draft_var.file_id == mock_draft_var_file.id
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True)
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
"""Test complete save workflow."""
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
@ -222,7 +222,7 @@ class TestWorkflowDraftVariableService:
name="test_var",
value=StringSegment(value="reset_value"),
)
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
with patch.object(service, "_reset_conv_var", return_value=expected_result, autospec=True) as mock_reset_conv:
result = service.reset_variable(workflow, variable)
mock_reset_conv.assert_called_once_with(workflow, variable)
@ -330,8 +330,8 @@ class TestWorkflowDraftVariableService:
# Mock workflow methods
mock_node_config = {"type": "test_node"}
with (
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config, autospec=True),
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM, autospec=True),
):
result = service._reset_node_var_or_sys_var(workflow, variable)

View File

@ -1,12 +1,7 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
@ -18,109 +13,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
mock_session_maker = MagicMock()
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
@pytest.fixture
def mock_execution(self):
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.id = str(uuid4())
execution.tenant_id = "tenant-123"
execution.app_id = "app-456"
execution.workflow_id = "workflow-789"
execution.workflow_run_id = "run-101"
execution.node_id = "node-202"
execution.index = 1
execution.created_at = "2023-01-01T00:00:00Z"
return execution
def test_get_node_last_execution_found(self, repository, mock_execution):
"""Test getting the last execution for a node when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
compiled = call_args.compile()
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalars.return_value.all.return_value = []
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == []
mock_session.execute.assert_called_once()
def test_get_execution_by_id_found(self, repository, mock_execution):
"""Test getting execution by ID when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_execution_by_id(mock_execution.id)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
def test_get_execution_by_id_not_found(self, repository):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_execution_by_id("non-existent-id")
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_repository_implements_protocol(self, repository):
"""Test that the repository implements the required protocol methods."""
# Verify all protocol methods are implemented
@ -136,135 +28,3 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert callable(repository.delete_executions_by_app)
assert callable(repository.get_expired_executions_batch)
assert callable(repository.delete_executions_by_ids)
def test_delete_expired_executions(self, repository):
"""Test deleting expired executions."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
# Act
result = repository.delete_expired_executions(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository):
"""Test deleting executions by app."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"]
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
tenant_id="tenant-123",
app_id="app-456",
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository):
"""Test getting expired executions batch for backup."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Create mock execution objects
mock_execution1 = MagicMock()
mock_execution1.id = "exec-1"
mock_execution2 = MagicMock()
mock_execution2.id = "exec-2"
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
before_date = datetime(2023, 1, 1)
# Act
result = repository.get_expired_executions_batch(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert len(result) == 2
assert result[0].id == "exec-1"
assert result[1].id == "exec-2"
mock_session.execute.assert_called_once()
def test_delete_executions_by_ids(self, repository):
"""Test deleting executions by IDs."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
# Act
result = repository.delete_executions_by_ids(execution_ids)
# Assert
assert result == 3
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository):
"""Test deleting executions with empty ID list."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Act
result = repository.delete_executions_by_ids([])
# Assert
assert result == 0
mock_session.query.assert_not_called()
mock_session.commit.assert_not_called()