Merge branch 'deploy/enterprise' of https://github.com/langgenius/dify into deploy/enterprise

This commit is contained in:
GareArc
2026-02-09 01:57:13 -08:00
23 changed files with 863 additions and 607 deletions

View File

@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage.download.side_effect = mock_download
# Execute the task
# Execute the task - should raise ValueError for empty CSV
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
with pytest.raises(ValueError, match="The CSV file is empty"):
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
# Verify error handling
# Check Redis cache was set to error status
from extensions.ext_redis import redis_client
cache_key = f"segment_batch_import_{job_id}"
cache_value = redis_client.get(cache_key)
assert cache_value == b"error"
# Verify no segments were created
# Since exception was raised, no segments should be created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()

View File

@ -0,0 +1,182 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_update_task import document_indexing_update_task
class TestDocumentIndexingUpdateTask:
@pytest.fixture
def mock_external_dependencies(self):
"""Patch external collaborators used by the update task.
- IndexProcessorFactory.init_index_processor().clean(...)
- IndexingRunner.run([...])
"""
with (
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
):
processor_instance = MagicMock()
mock_factory.return_value.init_index_processor.return_value = processor_instance
runner_instance = MagicMock()
mock_runner.return_value = runner_instance
yield {
"factory": mock_factory,
"processor": processor_instance,
"runner": mock_runner,
"runner_instance": runner_instance,
}
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
fake = Faker()
# Account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=fake.company(), status="normal")
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Dataset and document
dataset = Dataset(
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
document = Document(
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Segments
node_ids = []
for i in range(segment_count):
node_id = f"node-{i + 1}"
seg = DocumentSegment(
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
content=fake.text(max_nb_chars=32),
answer=None,
word_count=10,
tokens=5,
index_node_id=node_id,
status="completed",
created_by=account.id,
)
db_session_with_containers.add(seg)
node_ids.append(node_id)
db_session_with_containers.commit()
# Refresh to ensure ORM state
db_session_with_containers.refresh(dataset)
db_session_with_containers.refresh(document)
return dataset, document, node_ids
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Act
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
assert updated.indexing_status == "parsing"
assert updated.processing_started_at is not None
# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining == 0
# Assert index processor clean was called with expected args
clean_call = mock_external_dependencies["processor"].clean.call_args
assert clean_call is not None
args, kwargs = clean_call
# args[0] is a Dataset instance (from another session) — validate by id
assert getattr(args[0], "id", None) == dataset.id
# args[1] should contain our node_ids
assert set(args[1]) == set(node_ids)
assert kwargs.get("with_keywords") is True
assert kwargs.get("delete_child_chunks") is True
# Assert indexing runner invoked with the updated document
run_call = mock_external_dependencies["runner_instance"].run.call_args
assert run_call is not None
run_docs = run_call[0][0]
assert len(run_docs) == 1
first = run_docs[0]
assert getattr(first, "id", None) == document.id
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Force clean to raise; task should continue to indexing
mock_external_dependencies["processor"].clean.side_effect = Exception("boom")
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Indexing should still be triggered
mock_external_dependencies["runner_instance"].run.assert_called_once()
# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining > 0
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
fake = Faker()
# Act with non-existent document id
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())
# Neither processor nor runner should be called
mock_external_dependencies["processor"].clean.assert_not_called()
mock_external_dependencies["runner_instance"].run.assert_not_called()

View File

@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
sessions = [] # Track all created sessions
# Shared mock data that all sessions will access
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
def _exit_side_effect(*args, **kwargs):
session.close()
def create_session_side_effect():
session = MagicMock()
session.close = MagicMock()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
# Track commit calls
commit_mock = MagicMock()
session.commit = commit_mock
cm = MagicMock()
cm.__enter__.return_value = session
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
# Support session.begin() for transactions
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def begin_exit_side_effect(*args, **kwargs):
# Auto-commit on transaction exit (like SQLAlchemy)
session.commit()
# Also mark wrapper's commit as called
if sessions:
sessions[0].commit()
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
session.begin = MagicMock(return_value=begin_cm)
sessions.append(session)
# Setup query with side_effect to handle both Dataset and Document queries
def query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
# Support both .first() and .all() calls with chaining
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
# Create an iterator for .first() calls if not exists
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
session.query = MagicMock(side_effect=query_side_effect)
return cm
mock_sf.create_session.side_effect = create_session_side_effect
# Create a wrapper that behaves like the first session but has access to all sessions
class SessionWrapper:
def __init__(self):
self._sessions = sessions
self._shared_data = shared_mock_data
# Create a default session for setup phase
self._default_session = MagicMock()
self._default_session.close = MagicMock()
self._default_session.commit = MagicMock()
# Support session.begin() for default session too
begin_cm = MagicMock()
begin_cm.__enter__.return_value = self._default_session
def default_begin_exit_side_effect(*args, **kwargs):
self._default_session.commit()
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
self._default_session.begin = MagicMock(return_value=begin_cm)
def default_query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
def __getattr__(self, name):
# Forward all attribute access to the first session, or default if none created yet
target_session = self._sessions[0] if self._sessions else self._default_session
return getattr(target_session, name)
@property
def all_sessions(self):
"""Access all created sessions for testing."""
return self._sessions
wrapper = SessionWrapper()
yield wrapper
@pytest.fixture
@ -252,18 +356,9 @@ class TestTaskEnqueuing:
use the deprecated function.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -304,21 +399,9 @@ class TestBatchProcessing:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create an iterator for documents
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -357,19 +440,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
@ -407,19 +480,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
@ -444,7 +507,10 @@ class TestBatchProcessing:
"""
# Arrange
document_ids = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Set shared mock data with empty documents list
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = []
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -482,19 +548,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -528,19 +584,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -635,19 +681,9 @@ class TestErrorHandling:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set up to trigger vector space limit error
mock_feature_service.get_features.return_value.billing.enabled = True
@ -674,17 +710,9 @@ class TestErrorHandling:
Errors during indexing should be caught and logged, but not crash the task.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
@ -708,17 +736,9 @@ class TestErrorHandling:
but not treated as a failure.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise DocumentIsPausedError
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
@ -853,17 +873,9 @@ class TestTaskCancellation:
Session cleanup should happen in finally block.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -883,17 +895,9 @@ class TestTaskCancellation:
Session cleanup should happen even when errors occur.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Test error")
@ -962,6 +966,7 @@ class TestAdvancedScenarios:
document_ids = [str(uuid.uuid4()) for _ in range(3)]
# Create only 2 documents (simulate one missing)
# The new code uses .all() which will only return existing documents
mock_documents = []
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
doc = MagicMock(spec=Document)
@ -971,21 +976,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create iterator that returns None for missing document
doc_responses = [mock_documents[0], None, mock_documents[1]]
doc_iter = iter(doc_responses)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data - .all() will only return existing documents
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set vector space exactly at limit
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Billing disabled - limits should not be checked
mock_feature_service.get_features.return_value.billing.enabled = False
@ -1273,19 +1246,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1321,19 +1284,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1415,17 +1368,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1465,17 +1410,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1555,19 +1492,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set vector space limit to 0 (unlimited)
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1612,19 +1539,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set negative vector space limit
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Configure billing with sufficient limits
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1826,19 +1733,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@ -1866,7 +1763,7 @@ class TestRobustness:
- No exceptions occur
Expected behavior:
- Database session is closed
- All database sessions are closed
- No connection leaks
"""
# Arrange
@ -1879,19 +1776,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1899,10 +1786,11 @@ class TestRobustness:
# Act
_document_indexing(dataset_id, document_ids)
# Assert
assert mock_db_session.close.called
# Verify close is called exactly once
assert mock_db_session.close.call_count == 1
# Assert - All created sessions should be closed
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
assert len(mock_db_session.all_sessions) >= 1
for session in mock_db_session.all_sessions:
assert session.close.called, "All sessions should be closed"
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
"""

View File

@ -114,6 +114,21 @@ def mock_db_session():
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager to auto-commit on exit
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def _begin_exit_side_effect(*args, **kwargs):
# session.begin().__exit__() should commit if no exception
if args[0] is None: # No exception
session.commit()
begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.begin.return_value = begin_cm
# Mock create_session() context manager
cm = MagicMock()
cm.__enter__.return_value = session