refactor: use session factory instead of call db.session directly (#31198)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei
2026-01-21 13:43:06 +08:00
committed by GitHub
parent 071bbc6d74
commit 121d301a41
48 changed files with 2788 additions and 2693 deletions

View File

@ -39,23 +39,22 @@ class TestCleanDatasetTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database before each test to ensure isolation."""
from extensions.ext_database import db
from extensions.ext_redis import redis_client
# Clear all test data
db.session.query(DatasetMetadataBinding).delete()
db.session.query(DatasetMetadata).delete()
db.session.query(AppDatasetJoin).delete()
db.session.query(DatasetQuery).delete()
db.session.query(DatasetProcessRule).delete()
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(UploadFile).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using the provided session fixture
db_session_with_containers.query(DatasetMetadataBinding).delete()
db_session_with_containers.query(DatasetMetadata).delete()
db_session_with_containers.query(AppDatasetJoin).delete()
db_session_with_containers.query(DatasetQuery).delete()
db_session_with_containers.query(DatasetProcessRule).delete()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@ -103,10 +102,8 @@ class TestCleanDatasetTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -115,8 +112,8 @@ class TestCleanDatasetTask:
status="active",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account relationship
tenant_account_join = TenantAccountJoin(
@ -125,8 +122,8 @@ class TestCleanDatasetTask:
role=TenantAccountRole.OWNER,
)
db.session.add(tenant_account_join)
db.session.commit()
db_session_with_containers.add(tenant_account_join)
db_session_with_containers.commit()
return account, tenant
@ -155,10 +152,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@ -194,10 +189,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@ -232,10 +225,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segment
@ -267,10 +258,8 @@ class TestCleanDatasetTask:
used=False,
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
@ -302,31 +291,29 @@ class TestCleanDatasetTask:
)
# Verify results
from extensions.ext_database import db
# Check that dataset-related data was cleaned up
documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(documents) == 0
segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(segments) == 0
# Check that metadata and bindings were cleaned up
metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(metadata) == 0
bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
assert len(bindings) == 0
# Check that process rules and queries were cleaned up
process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
assert len(process_rules) == 0
queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
assert len(queries) == 0
# Check that app dataset joins were cleaned up
app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
assert len(app_joins) == 0
# Verify index processor was called
@ -378,9 +365,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Create dataset metadata and bindings
metadata = DatasetMetadata(
@ -403,11 +388,9 @@ class TestCleanDatasetTask:
binding.id = str(uuid.uuid4())
binding.created_at = datetime.now()
from extensions.ext_database import db
db.session.add(metadata)
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(metadata)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@ -421,22 +404,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that metadata and bindings were cleaned up
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify index processor was called
@ -489,12 +474,13 @@ class TestCleanDatasetTask:
mock_index_processor.clean.assert_called_once()
# Check that all data was cleaned up
from extensions.ext_database import db
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = (
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_segments) == 0
# Recreate data for next test case
@ -540,14 +526,13 @@ class TestCleanDatasetTask:
)
# Verify results - even with vector cleanup failure, documents and segments should be deleted
from extensions.ext_database import db
# Check that documents were still deleted despite vector cleanup failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite vector cleanup failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Verify that index processor was called and failed
@ -608,10 +593,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Mock the get_image_upload_file_ids function to return our image file IDs
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
@ -629,16 +612,18 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all image files were deleted from database
image_file_ids = [f.id for f in image_files]
remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
remaining_image_files = (
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
)
assert len(remaining_image_files) == 0
# Verify that storage.delete was called for each image file
@ -745,22 +730,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that all metadata and bindings were deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify performance expectations
@ -808,9 +795,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock storage to raise exceptions
mock_storage = mock_external_service_dependencies["storage"]
@ -827,18 +812,13 @@ class TestCleanDatasetTask:
)
# Verify results
# Check that documents were still deleted despite storage failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite storage failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Note: When storage operations fail, database deletions may be rolled back by implementation.
# This test focuses on ensuring the task handles the exception and continues execution/logging.
# Check that upload file was still deleted from database despite storage failure
# Note: When storage operations fail, the upload file may not be deleted
# This demonstrates that the cleanup process continues even with storage errors
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
# The upload file should still be deleted from the database even if storage cleanup fails
# However, this depends on the specific implementation of clean_dataset_task
if len(remaining_files) > 0:
@ -890,10 +870,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create document with special characters in name
special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~"
@ -912,8 +890,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Create segment with special characters and very long content
long_content = "Very long content " * 100 # Long content within reasonable limits
@ -934,8 +912,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Create upload file with special characters in name
special_filename = f"test_file_{special_content}.txt"
@ -952,14 +930,14 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
used=False,
)
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
# Update document with file reference
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit()
db_session_with_containers.commit()
# Save upload file ID for verification
upload_file_id = upload_file.id
@ -975,8 +953,8 @@ class TestCleanDatasetTask:
special_metadata.id = str(uuid.uuid4())
special_metadata.created_at = datetime.now()
db.session.add(special_metadata)
db.session.commit()
db_session_with_containers.add(special_metadata)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@ -990,19 +968,19 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
assert len(remaining_files) == 0
# Check that all metadata was deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
# Verify that storage.delete was called

View File

@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database and Redis before each test to ensure isolation."""
from extensions.ext_database import db
# Clear all test data
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using fixture session
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask:
status="normal",
plan="basic",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join with owner role
join = TenantAccountJoin(
@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
)
# Mock global database session to simulate transaction issues
from extensions.ext_database import db
original_commit = db.session.commit
commit_called = False
def mock_commit():
nonlocal commit_called
if not commit_called:
commit_called = True
raise Exception("Database commit failed")
return original_commit()
db.session.commit = mock_commit
# Simulate an error during indexing to trigger rollback path
mock_processor = mock_external_service_dependencies["index_processor"]
mock_processor.load.side_effect = Exception("Simulated indexing error")
# Act: Execute the task
create_segment_to_index_task(segment.id)
@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask:
assert segment.disabled_at is not None
assert segment.error is not None
# Restore original commit method
db.session.commit = original_commit
def test_create_segment_to_index_metadata_validation(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask:
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
from extensions.ext_database import db
db.session.add(tenant)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Set the current tenant for the account
account.current_tenant = tenant
@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask:
built_in_field_enabled=False,
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en"
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask:
segments.append(segment)
from extensions.ext_database import db
for segment in segments:
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segments
@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask:
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
mock_redis.delete.return_value = True
# Mock db.session.close to verify it's called
with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Verify session was closed
mock_close.assert_called()
# Assert
assert result is None # Task should complete without returning a value
# Session lifecycle is managed by context manager; no explicit close assertion
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
"""

View File

@ -6,7 +6,6 @@ from faker import Faker
from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import (
@ -75,15 +74,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -92,8 +91,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -105,8 +104,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -124,13 +123,13 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -157,15 +156,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -174,8 +173,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -187,8 +186,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -206,10 +205,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@ -219,7 +218,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -242,6 +241,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -250,7 +252,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -310,6 +312,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_document_indexing(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -317,7 +322,7 @@ class TestDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -353,6 +358,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -361,7 +369,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing close the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -400,7 +408,7 @@ class TestDocumentIndexingTasks:
indexing_status="completed", # Already completed
enabled=True,
)
db.session.add(doc1)
db_session_with_containers.add(doc1)
extra_documents.append(doc1)
# Document with disabled status
@ -417,10 +425,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=False, # Disabled
)
db.session.add(doc2)
db_session_with_containers.add(doc2)
extra_documents.append(doc2)
db.session.commit()
db_session_with_containers.commit()
all_documents = base_documents + extra_documents
document_ids = [doc.id for doc in all_documents]
@ -428,6 +436,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document states
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -435,7 +446,7 @@ class TestDocumentIndexingTasks:
# Verify all documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -482,20 +493,23 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error
@ -526,6 +540,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with billing disabled
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -533,7 +550,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -565,6 +582,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -573,7 +593,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -674,6 +694,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred (same as _document_indexing)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -681,7 +704,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated (same as _document_indexing)
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -794,6 +817,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -802,7 +828,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -865,6 +891,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function for tenant1 only
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred for tenant1
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

View File

@ -4,7 +4,6 @@ import pytest
from faker import Faker
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -99,8 +98,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -112,8 +111,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -132,13 +131,13 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks:
indexing_at=fake.date_time_this_year(),
created_by=dataset.created_by, # Add required field
)
db.session.add(segment)
db_session_with_containers.add(segment)
segments.append(segment)
db.session.commit()
db_session_with_containers.commit()
# Refresh to ensure all relationships are loaded
for document in documents:
db.session.refresh(document)
db_session_with_containers.refresh(document)
return dataset, documents, segments
@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -234,8 +233,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -247,8 +246,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -267,10 +266,10 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks:
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
)
document_ids = [doc.id for doc in documents]
segment_ids = [seg.id for seg in segments]
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify segment cleanup
db_session_with_containers.expire_all()
# Assert: Verify segment cleanup
# Verify index processor clean was called for each document with segments
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
# Verify segments were deleted from database
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
for segment in segments:
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
# Re-query segments from database using captured IDs to avoid stale ORM instances
for seg_id in segment_ids:
deleted_segment = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first()
)
assert deleted_segment is None
# Verify documents were updated to parsing status
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_duplicate_document_indexing_task(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _duplicate_document_indexing_task close the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error.lower()
@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with documents that will exceed vector space limit
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "limit" in updated_document.error.lower()
@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")