mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
refactor: use session factory instead of call db.session directly (#31198)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -49,10 +49,14 @@ def pipeline_id():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session with query capabilities."""
|
||||
with patch("tasks.clean_dataset_task.db") as mock_db:
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
# context manager for create_session()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = mock_session
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
# Setup query chain
|
||||
mock_query = MagicMock()
|
||||
@ -66,7 +70,10 @@ def mock_db_session():
|
||||
# Setup execute for JOIN queries
|
||||
mock_session.execute.return_value.all.return_value = []
|
||||
|
||||
yield mock_db
|
||||
# Yield an object with a `.session` attribute to keep tests unchanged
|
||||
wrapper = MagicMock()
|
||||
wrapper.session = mock_session
|
||||
yield wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -227,7 +234,9 @@ class TestBasicCleanup:
|
||||
|
||||
# Assert
|
||||
mock_db_session.session.delete.assert_any_call(mock_document)
|
||||
mock_db_session.session.delete.assert_any_call(mock_segment)
|
||||
# Segments are deleted in batch; verify a DELETE on document_segments was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
mock_db_session.session.commit.assert_called_once()
|
||||
|
||||
def test_clean_dataset_task_deletes_related_records(
|
||||
@ -413,7 +422,9 @@ class TestErrorHandling:
|
||||
|
||||
# Assert - documents and segments should still be deleted
|
||||
mock_db_session.session.delete.assert_any_call(mock_document)
|
||||
mock_db_session.session.delete.assert_any_call(mock_segment)
|
||||
# Segments are deleted in batch; verify a DELETE on document_segments was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
mock_db_session.session.commit.assert_called_once()
|
||||
|
||||
def test_clean_dataset_task_storage_delete_failure_continues(
|
||||
@ -461,7 +472,7 @@ class TestErrorHandling:
|
||||
[mock_segment], # segments
|
||||
]
|
||||
mock_get_image_upload_file_ids.return_value = [image_file_id]
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
|
||||
mock_storage.delete.side_effect = Exception("Storage service unavailable")
|
||||
|
||||
# Act
|
||||
@ -476,8 +487,9 @@ class TestErrorHandling:
|
||||
|
||||
# Assert - storage delete was attempted for image file
|
||||
mock_storage.delete.assert_called_with(mock_upload_file.key)
|
||||
# Image file should still be deleted from database
|
||||
mock_db_session.session.delete.assert_any_call(mock_upload_file)
|
||||
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_database_error_rollback(
|
||||
self,
|
||||
@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup:
|
||||
|
||||
# Assert
|
||||
mock_storage.delete.assert_called_with(mock_attachment_file.key)
|
||||
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
|
||||
mock_db_session.session.delete.assert_any_call(mock_binding)
|
||||
# Attachment file and binding are deleted in batch; verify DELETEs were issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_attachment_storage_failure(
|
||||
self,
|
||||
@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup:
|
||||
|
||||
# Assert - storage delete was attempted
|
||||
mock_storage.delete.assert_called_once()
|
||||
# Records should still be deleted from database
|
||||
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
|
||||
mock_db_session.session.delete.assert_any_call(mock_binding)
|
||||
# Records are deleted in batch; verify DELETEs were issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@ -784,7 +799,7 @@ class TestUploadFileCleanup:
|
||||
[mock_document], # documents
|
||||
[], # segments
|
||||
]
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
|
||||
|
||||
# Act
|
||||
clean_dataset_task(
|
||||
@ -798,7 +813,9 @@ class TestUploadFileCleanup:
|
||||
|
||||
# Assert
|
||||
mock_storage.delete.assert_called_with(mock_upload_file.key)
|
||||
mock_db_session.session.delete.assert_any_call(mock_upload_file)
|
||||
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_handles_missing_upload_file(
|
||||
self,
|
||||
@ -832,7 +849,7 @@ class TestUploadFileCleanup:
|
||||
[mock_document], # documents
|
||||
[], # segments
|
||||
]
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
|
||||
|
||||
# Act - should not raise exception
|
||||
clean_dataset_task(
|
||||
@ -949,11 +966,11 @@ class TestImageFileCleanup:
|
||||
[mock_segment], # segments
|
||||
]
|
||||
|
||||
# Setup a mock query chain that returns files in sequence
|
||||
# Setup a mock query chain that returns files in batch (align with .in_().all())
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.first.side_effect = mock_image_files
|
||||
mock_where.all.return_value = mock_image_files
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
@ -966,10 +983,10 @@ class TestImageFileCleanup:
|
||||
doc_form="paragraph_index",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_storage.delete.call_count == 2
|
||||
mock_storage.delete.assert_any_call("images/image-1.jpg")
|
||||
mock_storage.delete.assert_any_call("images/image-2.jpg")
|
||||
# Assert - each expected image key was deleted at least once
|
||||
calls = [c.args[0] for c in mock_storage.delete.call_args_list]
|
||||
assert "images/image-1.jpg" in calls
|
||||
assert "images/image-2.jpg" in calls
|
||||
|
||||
def test_clean_dataset_task_handles_missing_image_file(
|
||||
self,
|
||||
@ -1010,7 +1027,7 @@ class TestImageFileCleanup:
|
||||
]
|
||||
|
||||
# Image file not found
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
|
||||
|
||||
# Act - should not raise exception
|
||||
clean_dataset_task(
|
||||
@ -1086,14 +1103,15 @@ class TestEdgeCases:
|
||||
doc_form="paragraph_index",
|
||||
)
|
||||
|
||||
# Assert - all documents and segments should be deleted
|
||||
# Assert - all documents and segments should be deleted (documents per-entity, segments in batch)
|
||||
delete_calls = mock_db_session.session.delete.call_args_list
|
||||
deleted_items = [call[0][0] for call in delete_calls]
|
||||
|
||||
for doc in mock_documents:
|
||||
assert doc in deleted_items
|
||||
for seg in mock_segments:
|
||||
assert seg in deleted_items
|
||||
# Verify a batch DELETE on document_segments occurred
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_document_with_empty_data_source_info(
|
||||
self,
|
||||
|
||||
@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.document_indexing_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
yield mock_session
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Ensure tests that expect session.close() to be called can observe it via the context manager
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock the db.session used in delete_account_task."""
|
||||
with patch("tasks.delete_account_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
yield mock_session
|
||||
"""Mock session via session_factory.create_session()."""
|
||||
with patch("tasks.delete_account_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -109,13 +109,25 @@ def mock_document_segments(document_id):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_session.scalars.return_value = MagicMock()
|
||||
yield mock_session
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Ensure tests can observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
session.scalars.return_value = MagicMock()
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask:
|
||||
# Assert
|
||||
# Document status should remain unchanged
|
||||
assert mock_document.indexing_status == "completed"
|
||||
# No session operations should be performed beyond the initial query
|
||||
mock_db_session.close.assert_not_called()
|
||||
# Session should still be closed via context manager teardown
|
||||
assert mock_db_session.close.called
|
||||
|
||||
def test_successful_sync_when_page_updated(
|
||||
self,
|
||||
@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask:
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify segments were deleted from database
|
||||
for segment in mock_document_segments:
|
||||
mock_db_session.delete.assert_any_call(segment)
|
||||
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
# Verify indexing runner was called
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
|
||||
@ -94,13 +94,25 @@ def mock_document_segments(document_ids):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_session.scalars.return_value = MagicMock()
|
||||
yield mock_session
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Allow tests to observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
session.scalars.return_value = MagicMock()
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test successful duplicate document indexing flow."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# Dataset via query.first()
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
# scalars() call sequence:
|
||||
# 1) documents list
|
||||
# 2..N) segments per document
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
# First call returns documents; subsequent calls return segments
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = mock_document_segments
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
@ -264,8 +293,21 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test duplicate document indexing when billing limit is exceeded."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
# First scalars() -> documents; subsequent -> empty segments
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_features = mock_feature_service.get_features.return_value
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.TEAM
|
||||
@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test duplicate document indexing when IndexingRunner raises an error."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test duplicate document indexing when document is paused."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test that duplicate document indexing cleans old segments."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = mock_document_segments
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
|
||||
# Act
|
||||
@ -354,9 +432,9 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
# Verify clean was called for each document
|
||||
assert mock_processor.clean.call_count == len(mock_documents)
|
||||
|
||||
# Verify segments were deleted
|
||||
for segment in mock_document_segments:
|
||||
mock_db_session.delete.assert_any_call(segment)
|
||||
# Verify segments were deleted in batch (DELETE FROM document_segments)
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@ -11,21 +11,18 @@ from tasks.remove_app_and_related_data_task import (
|
||||
|
||||
class TestDeleteDraftVariablesBatch:
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup):
|
||||
@patch("tasks.remove_app_and_related_data_task.session_factory")
|
||||
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
|
||||
"""Test successful deletion of draft variables in batches."""
|
||||
app_id = "test-app-id"
|
||||
batch_size = 100
|
||||
|
||||
# Mock database connection and engine
|
||||
mock_conn = MagicMock()
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
# Properly mock the context manager
|
||||
# Mock session via session_factory
|
||||
mock_session = MagicMock()
|
||||
mock_context_manager = MagicMock()
|
||||
mock_context_manager.__enter__.return_value = mock_conn
|
||||
mock_context_manager.__enter__.return_value = mock_session
|
||||
mock_context_manager.__exit__.return_value = None
|
||||
mock_engine.begin.return_value = mock_context_manager
|
||||
mock_sf.create_session.return_value = mock_context_manager
|
||||
|
||||
# Mock two batches of results, then empty
|
||||
batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
|
||||
@ -68,7 +65,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
select_result3.__iter__.return_value = iter([])
|
||||
|
||||
# Configure side effects in the correct order
|
||||
mock_conn.execute.side_effect = [
|
||||
mock_session.execute.side_effect = [
|
||||
select_result1, # First SELECT
|
||||
delete_result1, # First DELETE
|
||||
select_result2, # Second SELECT
|
||||
@ -86,54 +83,49 @@ class TestDeleteDraftVariablesBatch:
|
||||
assert result == 150
|
||||
|
||||
# Verify database calls
|
||||
assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes
|
||||
assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes
|
||||
|
||||
# Verify offload cleanup was called for both batches with file_ids
|
||||
expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)]
|
||||
expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
|
||||
mock_offload_cleanup.assert_has_calls(expected_offload_calls)
|
||||
|
||||
# Simplified verification - check that the right number of calls were made
|
||||
# and that the SQL queries contain the expected patterns
|
||||
actual_calls = mock_conn.execute.call_args_list
|
||||
actual_calls = mock_session.execute.call_args_list
|
||||
for i, actual_call in enumerate(actual_calls):
|
||||
sql_text = str(actual_call[0][0])
|
||||
normalized = " ".join(sql_text.split())
|
||||
if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4)
|
||||
# Verify it's a SELECT query that now includes file_id
|
||||
sql_text = str(actual_call[0][0])
|
||||
assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text
|
||||
assert "WHERE app_id = :app_id" in sql_text
|
||||
assert "LIMIT :batch_size" in sql_text
|
||||
assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
|
||||
assert "WHERE app_id = :app_id" in normalized
|
||||
assert "LIMIT :batch_size" in normalized
|
||||
else: # DELETE calls (odd indices: 1, 3)
|
||||
# Verify it's a DELETE query
|
||||
sql_text = str(actual_call[0][0])
|
||||
assert "DELETE FROM workflow_draft_variables" in sql_text
|
||||
assert "WHERE id IN :ids" in sql_text
|
||||
assert "DELETE FROM workflow_draft_variables" in normalized
|
||||
assert "WHERE id IN :ids" in normalized
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup):
|
||||
@patch("tasks.remove_app_and_related_data_task.session_factory")
|
||||
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
|
||||
"""Test deletion when no draft variables exist for the app."""
|
||||
app_id = "nonexistent-app-id"
|
||||
batch_size = 1000
|
||||
|
||||
# Mock database connection
|
||||
mock_conn = MagicMock()
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
# Properly mock the context manager
|
||||
# Mock session via session_factory
|
||||
mock_session = MagicMock()
|
||||
mock_context_manager = MagicMock()
|
||||
mock_context_manager.__enter__.return_value = mock_conn
|
||||
mock_context_manager.__enter__.return_value = mock_session
|
||||
mock_context_manager.__exit__.return_value = None
|
||||
mock_engine.begin.return_value = mock_context_manager
|
||||
mock_sf.create_session.return_value = mock_context_manager
|
||||
|
||||
# Mock empty result
|
||||
empty_result = MagicMock()
|
||||
empty_result.__iter__.return_value = iter([])
|
||||
mock_conn.execute.return_value = empty_result
|
||||
mock_session.execute.return_value = empty_result
|
||||
|
||||
result = delete_draft_variables_batch(app_id, batch_size)
|
||||
|
||||
assert result == 0
|
||||
assert mock_conn.execute.call_count == 1 # Only one select query
|
||||
assert mock_session.execute.call_count == 1 # Only one select query
|
||||
mock_offload_cleanup.assert_not_called() # No files to clean up
|
||||
|
||||
def test_delete_draft_variables_batch_invalid_batch_size(self):
|
||||
@ -147,22 +139,19 @@ class TestDeleteDraftVariablesBatch:
|
||||
delete_draft_variables_batch(app_id, 0)
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
@patch("tasks.remove_app_and_related_data_task.session_factory")
|
||||
@patch("tasks.remove_app_and_related_data_task.logger")
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup):
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
|
||||
"""Test that batch deletion logs progress correctly."""
|
||||
app_id = "test-app-id"
|
||||
batch_size = 50
|
||||
|
||||
# Mock database
|
||||
mock_conn = MagicMock()
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
# Properly mock the context manager
|
||||
# Mock session via session_factory
|
||||
mock_session = MagicMock()
|
||||
mock_context_manager = MagicMock()
|
||||
mock_context_manager.__enter__.return_value = mock_conn
|
||||
mock_context_manager.__enter__.return_value = mock_session
|
||||
mock_context_manager.__exit__.return_value = None
|
||||
mock_engine.begin.return_value = mock_context_manager
|
||||
mock_sf.create_session.return_value = mock_context_manager
|
||||
|
||||
# Mock one batch then empty
|
||||
batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
|
||||
@ -183,7 +172,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
empty_result = MagicMock()
|
||||
empty_result.__iter__.return_value = iter([])
|
||||
|
||||
mock_conn.execute.side_effect = [
|
||||
mock_session.execute.side_effect = [
|
||||
# Select query result
|
||||
select_result,
|
||||
# Delete query result
|
||||
@ -201,7 +190,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
|
||||
# Verify offload cleanup was called with file_ids
|
||||
if batch_file_ids:
|
||||
mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids)
|
||||
mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)
|
||||
|
||||
# Verify logging calls
|
||||
assert mock_logging.info.call_count == 2
|
||||
@ -261,19 +250,19 @@ class TestDeleteDraftVariableOffloadData:
|
||||
actual_calls = mock_conn.execute.call_args_list
|
||||
|
||||
# First call should be the SELECT query
|
||||
select_call_sql = str(actual_calls[0][0][0])
|
||||
select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
|
||||
assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
|
||||
assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
|
||||
assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
|
||||
assert "WHERE wdvf.id IN :file_ids" in select_call_sql
|
||||
|
||||
# Second call should be DELETE upload_files
|
||||
delete_upload_call_sql = str(actual_calls[1][0][0])
|
||||
delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
|
||||
assert "DELETE FROM upload_files" in delete_upload_call_sql
|
||||
assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql
|
||||
|
||||
# Third call should be DELETE workflow_draft_variable_files
|
||||
delete_variable_files_call_sql = str(actual_calls[2][0][0])
|
||||
delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
|
||||
assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
|
||||
assert "WHERE id IN :file_ids" in delete_variable_files_call_sql
|
||||
|
||||
|
||||
Reference in New Issue
Block a user