refactor: implement tenant self queue for rag tasks

This commit is contained in:
hj24
2025-10-28 14:20:43 +08:00
parent 4a797ab2d8
commit 2c2b3092f6
24 changed files with 3667 additions and 92 deletions

View File

@ -0,0 +1,369 @@
"""
Unit tests for TenantSelfTaskQueue.
These tests verify the Redis-based task queue functionality for tenant-specific
task management with proper serialization and deserialization.
"""
import json
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from core.rag.pipeline.queue import TASK_WRAPPER_PREFIX, TaskWrapper, TenantSelfTaskQueue
class TestTaskWrapper:
    """Test cases for TaskWrapper serialization/deserialization."""

    def test_serialize_simple_data(self):
        """Test serialization of simple data types."""
        payload = {"key": "value", "number": 42, "list": [1, 2, 3]}
        serialized = TaskWrapper(payload).serialize()
        assert isinstance(serialized, str)
        # Round-trip through json.loads to prove the output is valid JSON.
        assert json.loads(serialized) == payload

    def test_serialize_complex_data(self):
        """Test serialization of complex nested data."""
        payload = {
            "nested": {
                "deep": {
                    "value": "test",
                    "numbers": [1, 2, 3, 4, 5],
                }
            },
            "unicode": "测试中文",
            "special_chars": "!@#$%^&*()",
        }
        serialized = TaskWrapper(payload).serialize()
        assert json.loads(serialized) == payload

    def test_deserialize_valid_data(self):
        """Test deserialization of valid JSON data."""
        expected = {"key": "value", "number": 42}
        wrapper = TaskWrapper.deserialize(json.dumps(expected, ensure_ascii=False))
        assert wrapper.data == expected

    def test_deserialize_invalid_json(self):
        """Test deserialization handles invalid JSON gracefully."""
        # Malformed input must surface the underlying JSONDecodeError.
        with pytest.raises(json.JSONDecodeError):
            TaskWrapper.deserialize("{invalid json}")

    def test_serialize_ensure_ascii_false(self):
        """Test that serialization preserves Unicode characters."""
        serialized = TaskWrapper({"chinese": "中文测试", "emoji": "🚀"}).serialize()
        # ensure_ascii=False keeps raw characters rather than \uXXXX escapes.
        assert "中文测试" in serialized
        assert "🚀" in serialized
class TestTenantSelfTaskQueue:
    """Test cases for TenantSelfTaskQueue functionality.

    Every test patches ``core.rag.pipeline.queue.redis_client`` so no real
    Redis connection is needed; behavior is asserted purely by inspecting
    the mock's recorded calls and arguments.
    """

    @pytest.fixture
    def mock_redis_client(self):
        """Mock Redis client for testing."""
        mock_redis = MagicMock()
        return mock_redis

    @pytest.fixture
    def sample_queue(self, mock_redis_client):
        """Create a sample TenantSelfTaskQueue instance."""
        # NOTE(review): the mock_redis_client fixture is requested here but
        # never used — tests patch the module-level redis_client instead.
        # Confirm whether this dependency can be dropped.
        return TenantSelfTaskQueue("tenant-123", "test-key")

    def test_initialization(self, sample_queue):
        """Test queue initialization with correct key generation."""
        assert sample_queue.tenant_id == "tenant-123"
        assert sample_queue.unique_key == "test-key"
        assert sample_queue.queue == "tenant_self_test-key_task_queue:tenant-123"
        assert sample_queue.task_key == "tenant_test-key_task:tenant-123"
        assert sample_queue.DEFAULT_TASK_TTL == 60 * 60

    @patch('core.rag.pipeline.queue.redis_client')
    def test_get_task_key_exists(self, mock_redis, sample_queue):
        """Test getting task key when it exists."""
        mock_redis.get.return_value = "1"
        result = sample_queue.get_task_key()
        assert result == "1"
        mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch('core.rag.pipeline.queue.redis_client')
    def test_get_task_key_not_exists(self, mock_redis, sample_queue):
        """Test getting task key when it doesn't exist."""
        mock_redis.get.return_value = None
        result = sample_queue.get_task_key()
        assert result is None
        mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch('core.rag.pipeline.queue.redis_client')
    def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
        """Test setting task waiting flag with default TTL."""
        sample_queue.set_task_waiting_time()
        mock_redis.setex.assert_called_once_with(
            "tenant_test-key_task:tenant-123",
            3600,  # DEFAULT_TASK_TTL
            1
        )

    @patch('core.rag.pipeline.queue.redis_client')
    def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
        """Test setting task waiting flag with custom TTL."""
        custom_ttl = 1800
        sample_queue.set_task_waiting_time(custom_ttl)
        mock_redis.setex.assert_called_once_with(
            "tenant_test-key_task:tenant-123",
            custom_ttl,
            1
        )

    @patch('core.rag.pipeline.queue.redis_client')
    def test_delete_task_key(self, mock_redis, sample_queue):
        """Test deleting task key."""
        sample_queue.delete_task_key()
        mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch('core.rag.pipeline.queue.redis_client')
    def test_push_tasks_string_list(self, mock_redis, sample_queue):
        """Test pushing string tasks directly."""
        tasks = ["task1", "task2", "task3"]
        sample_queue.push_tasks(tasks)
        # Plain strings are pushed as-is, without the wrapper prefix.
        mock_redis.lpush.assert_called_once_with(
            "tenant_self_test-key_task_queue:tenant-123",
            "task1", "task2", "task3"
        )

    @patch('core.rag.pipeline.queue.redis_client')
    def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
        """Test pushing mixed string and object tasks."""
        tasks = [
            "string_task",
            {"object_task": "data", "id": 123},
            "another_string"
        ]
        sample_queue.push_tasks(tasks)
        # Verify lpush was called
        mock_redis.lpush.assert_called_once()
        call_args = mock_redis.lpush.call_args
        # Check queue name
        assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
        # Check serialized tasks
        serialized_tasks = call_args[0][1:]
        assert len(serialized_tasks) == 3
        assert serialized_tasks[0] == "string_task"
        assert serialized_tasks[2] == "another_string"
        # Check object task is wrapped
        assert serialized_tasks[1].startswith(TASK_WRAPPER_PREFIX)
        wrapper_data = serialized_tasks[1][len(TASK_WRAPPER_PREFIX):]
        parsed_data = json.loads(wrapper_data)
        assert parsed_data == {"object_task": "data", "id": 123}

    @patch('core.rag.pipeline.queue.redis_client')
    def test_push_tasks_empty_list(self, mock_redis, sample_queue):
        """Test pushing empty task list."""
        # NOTE(review): with an empty list, lpush is still invoked with only
        # the queue name; a real Redis server rejects LPUSH without values.
        # Confirm whether push_tasks should short-circuit on empty input.
        sample_queue.push_tasks([])
        mock_redis.lpush.assert_called_once_with(
            "tenant_self_test-key_task_queue:tenant-123"
        )

    @patch('core.rag.pipeline.queue.redis_client')
    def test_pull_tasks_default_count(self, mock_redis, sample_queue):
        """Test pulling tasks with default count (1)."""
        mock_redis.rpop.side_effect = ["task1", None]
        result = sample_queue.pull_tasks()
        assert result == ["task1"]
        assert mock_redis.rpop.call_count == 1

    @patch('core.rag.pipeline.queue.redis_client')
    def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
        """Test pulling tasks with custom count."""
        # First test: pull 3 tasks
        mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
        result = sample_queue.pull_tasks(3)
        assert result == ["task1", "task2", "task3"]
        assert mock_redis.rpop.call_count == 3
        # Reset mock for second test
        mock_redis.reset_mock()
        mock_redis.rpop.side_effect = ["task1", "task2", None]
        result = sample_queue.pull_tasks(3)
        assert result == ["task1", "task2"]
        # rpop is attempted 3 times; the third call returns None and ends the pull.
        assert mock_redis.rpop.call_count == 3

    @patch('core.rag.pipeline.queue.redis_client')
    def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
        """Test pulling tasks with zero count returns empty list."""
        result = sample_queue.pull_tasks(0)
        assert result == []
        mock_redis.rpop.assert_not_called()

    @patch('core.rag.pipeline.queue.redis_client')
    def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
        """Test pulling tasks with negative count returns empty list."""
        result = sample_queue.pull_tasks(-1)
        assert result == []
        mock_redis.rpop.assert_not_called()

    @patch('core.rag.pipeline.queue.redis_client')
    def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
        """Test pulling tasks that include wrapped objects."""
        # Create a wrapped task
        wrapper = TaskWrapper({"task_id": 123, "data": "test"})
        wrapped_task = f"{TASK_WRAPPER_PREFIX}{wrapper.serialize()}"
        mock_redis.rpop.side_effect = [
            "string_task",
            wrapped_task.encode('utf-8'),  # Simulate bytes from Redis
            None
        ]
        result = sample_queue.pull_tasks(2)
        assert len(result) == 2
        assert result[0] == "string_task"
        # Wrapped entries are deserialized back into their original objects.
        assert result[1] == {"task_id": 123, "data": "test"}

    @patch('core.rag.pipeline.queue.redis_client')
    def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
        """Test pulling tasks with invalid wrapped data falls back to string."""
        invalid_wrapped = f"{TASK_WRAPPER_PREFIX}invalid json"
        mock_redis.rpop.side_effect = [invalid_wrapped, None]
        result = sample_queue.pull_tasks(1)
        # Undecodable payloads are returned verbatim rather than raising.
        assert result == [invalid_wrapped]

    @patch('core.rag.pipeline.queue.redis_client')
    def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
        """Test pulling tasks handles bytes from Redis correctly."""
        mock_redis.rpop.side_effect = [
            b"task1",  # bytes
            "task2",  # string
            None
        ]
        result = sample_queue.pull_tasks(2)
        assert result == ["task1", "task2"]

    @patch('core.rag.pipeline.queue.redis_client')
    def test_get_next_task_success(self, mock_redis, sample_queue):
        """Test getting next single task."""
        mock_redis.rpop.side_effect = ["task1", None]
        result = sample_queue.get_next_task()
        assert result == "task1"
        assert mock_redis.rpop.call_count == 1

    @patch('core.rag.pipeline.queue.redis_client')
    def test_get_next_task_empty_queue(self, mock_redis, sample_queue):
        """Test getting next task when queue is empty."""
        mock_redis.rpop.return_value = None
        result = sample_queue.get_next_task()
        assert result is None
        mock_redis.rpop.assert_called_once()

    def test_tenant_isolation(self):
        """Test that different tenants have isolated queues."""
        queue1 = TenantSelfTaskQueue("tenant-1", "key")
        queue2 = TenantSelfTaskQueue("tenant-2", "key")
        assert queue1.queue != queue2.queue
        assert queue1.task_key != queue2.task_key
        assert queue1.queue == "tenant_self_key_task_queue:tenant-1"
        assert queue2.queue == "tenant_self_key_task_queue:tenant-2"

    def test_key_isolation(self):
        """Test that different keys have isolated queues."""
        queue1 = TenantSelfTaskQueue("tenant", "key1")
        queue2 = TenantSelfTaskQueue("tenant", "key2")
        assert queue1.queue != queue2.queue
        assert queue1.task_key != queue2.task_key
        assert queue1.queue == "tenant_self_key1_task_queue:tenant"
        assert queue2.queue == "tenant_self_key2_task_queue:tenant"

    @patch('core.rag.pipeline.queue.redis_client')
    def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
        """Test complex object serialization and deserialization roundtrip."""
        complex_task = {
            "id": uuid4().hex,
            "data": {
                "nested": {
                    "deep": [1, 2, 3],
                    "unicode": "测试中文",
                    "special": "!@#$%^&*()"
                }
            },
            "metadata": {
                "created_at": "2024-01-01T00:00:00Z",
                "tags": ["tag1", "tag2", "tag3"]
            }
        }
        # Push the complex task
        sample_queue.push_tasks([complex_task])
        # Verify it was wrapped
        call_args = mock_redis.lpush.call_args
        wrapped_task = call_args[0][1]
        assert wrapped_task.startswith(TASK_WRAPPER_PREFIX)
        # Simulate pulling it back
        mock_redis.rpop.return_value = wrapped_task
        result = sample_queue.pull_tasks(1)
        assert len(result) == 1
        assert result[0] == complex_task

    @patch('core.rag.pipeline.queue.redis_client')
    def test_large_task_list_handling(self, mock_redis, sample_queue):
        """Test handling of large task lists."""
        large_task_list = [f"task_{i}" for i in range(1000)]
        sample_queue.push_tasks(large_task_list)
        # Verify all tasks were pushed
        call_args = mock_redis.lpush.call_args
        assert len(call_args[0]) == 1001  # queue name + 1000 tasks
        assert list(call_args[0][1:]) == large_task_list

View File

@ -0,0 +1,332 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantSelfTaskQueue
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
    """Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""

    @staticmethod
    def create_mock_features(
        billing_enabled: bool = False,
        plan: str = "sandbox"
    ) -> Mock:
        """Create mock features with billing configuration."""
        # Build the billing sub-mock first, then attach it to the features mock.
        subscription = Mock()
        subscription.plan = plan
        billing = Mock()
        billing.enabled = billing_enabled
        billing.subscription = subscription
        features = Mock()
        features.billing = billing
        return features

    @staticmethod
    def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
        """Create mock TenantSelfTaskQueue."""
        tenant_queue = Mock(spec=TenantSelfTaskQueue)
        tenant_queue.get_task_key.return_value = "task_key" if has_task_key else None
        tenant_queue.push_tasks = Mock()
        tenant_queue.set_task_waiting_time = Mock()
        return tenant_queue

    @staticmethod
    def create_document_task_proxy(
        tenant_id: str = "tenant-123",
        dataset_id: str = "dataset-456",
        document_ids: list[str] | None = None
    ) -> DocumentIndexingTaskProxy:
        """Create DocumentIndexingTaskProxy instance for testing."""
        # Default to three document ids when the caller supplies none.
        resolved_ids = ["doc-1", "doc-2", "doc-3"] if document_ids is None else document_ids
        return DocumentIndexingTaskProxy(tenant_id, dataset_id, resolved_ids)
class TestDocumentIndexingTaskProxy:
    """Test cases for DocumentIndexingTaskProxy class.

    Celery task objects and FeatureService are patched throughout, so the
    dispatch routing (tenant queue vs. direct queue, default vs. priority)
    is asserted purely through mock call inspection.
    """

    def test_initialization(self):
        """Test DocumentIndexingTaskProxy initialization."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = ["doc-1", "doc-2", "doc-3"]
        # Act
        proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
        # Assert
        assert proxy.tenant_id == tenant_id
        assert proxy.dateset_id == dataset_id  # Note: typo in original code
        assert proxy.document_ids == document_ids
        assert isinstance(proxy.tenant_self_task_queue, TenantSelfTaskQueue)
        assert proxy.tenant_self_task_queue.tenant_id == tenant_id
        assert proxy.tenant_self_task_queue.unique_key == "document_indexing"

    @patch('services.document_indexing_task_proxy.FeatureService')
    def test_features_property(self, mock_feature_service):
        """Test cached_property features."""
        # Arrange
        mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
        mock_feature_service.get_features.return_value = mock_features
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        # Act
        features1 = proxy.features
        features2 = proxy.features  # Second call should use cached property
        # Assert
        assert features1 == mock_features
        assert features2 == mock_features
        assert features1 is features2  # Should be the same instance due to caching
        mock_feature_service.get_features.assert_called_once_with("tenant-123")

    @patch('services.document_indexing_task_proxy.normal_document_indexing_task')
    def test_send_to_direct_queue(self, mock_task):
        """Test _send_to_direct_queue method."""
        # Arrange
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        mock_task.delay = Mock()
        # Act
        proxy._send_to_direct_queue(mock_task)
        # Assert
        mock_task.delay.assert_called_once_with(
            tenant_id="tenant-123",
            dataset_id="dataset-456",
            document_ids=["doc-1", "doc-2", "doc-3"]
        )

    @patch('services.document_indexing_task_proxy.normal_document_indexing_task')
    def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
        """Test _send_to_tenant_queue when task key exists."""
        # Arrange
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy.tenant_self_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=True
        )
        mock_task.delay = Mock()
        # Act
        proxy._send_to_tenant_queue(mock_task)
        # Assert
        proxy.tenant_self_task_queue.push_tasks.assert_called_once()
        pushed_tasks = proxy.tenant_self_task_queue.push_tasks.call_args[0][0]
        assert len(pushed_tasks) == 1
        # The pushed payload must round-trip into a DocumentTask dataclass.
        assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
        assert pushed_tasks[0]["tenant_id"] == "tenant-123"
        assert pushed_tasks[0]["dataset_id"] == "dataset-456"
        assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
        mock_task.delay.assert_not_called()

    @patch('services.document_indexing_task_proxy.normal_document_indexing_task')
    def test_send_to_tenant_queue_without_task_key(self, mock_task):
        """Test _send_to_tenant_queue when no task key exists."""
        # Arrange
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy.tenant_self_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=False
        )
        mock_task.delay = Mock()
        # Act
        proxy._send_to_tenant_queue(mock_task)
        # Assert
        # With no task key, the task goes straight to celery and the
        # waiting flag is set; nothing is queued in Redis.
        proxy.tenant_self_task_queue.set_task_waiting_time.assert_called_once()
        mock_task.delay.assert_called_once_with(
            tenant_id="tenant-123",
            dataset_id="dataset-456",
            document_ids=["doc-1", "doc-2", "doc-3"]
        )
        proxy.tenant_self_task_queue.push_tasks.assert_not_called()

    @patch('services.document_indexing_task_proxy.normal_document_indexing_task')
    def test_send_to_default_tenant_queue(self, mock_task):
        """Test _send_to_default_tenant_queue method."""
        # Arrange
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_tenant_queue = Mock()
        # Act
        proxy._send_to_default_tenant_queue()
        # Assert
        proxy._send_to_tenant_queue.assert_called_once_with(mock_task)

    @patch('services.document_indexing_task_proxy.priority_document_indexing_task')
    def test_send_to_priority_tenant_queue(self, mock_task):
        """Test _send_to_priority_tenant_queue method."""
        # Arrange
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_tenant_queue = Mock()
        # Act
        proxy._send_to_priority_tenant_queue()
        # Assert
        proxy._send_to_tenant_queue.assert_called_once_with(mock_task)

    @patch('services.document_indexing_task_proxy.priority_document_indexing_task')
    def test_send_to_priority_direct_queue(self, mock_task):
        """Test _send_to_priority_direct_queue method."""
        # Arrange
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_direct_queue = Mock()
        # Act
        proxy._send_to_priority_direct_queue()
        # Assert
        proxy._send_to_direct_queue.assert_called_once_with(mock_task)

    @patch('services.document_indexing_task_proxy.FeatureService')
    def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with sandbox plan."""
        # Arrange
        mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan="sandbox"
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_default_tenant_queue = Mock()
        # Act
        proxy._dispatch()
        # Assert
        proxy._send_to_default_tenant_queue.assert_called_once()

    @patch('services.document_indexing_task_proxy.FeatureService')
    def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with non-sandbox plan."""
        # Arrange
        mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan="team"
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()
        # Act
        proxy._dispatch()
        # If billing enabled with non sandbox plan, should send to priority tenant queue
        proxy._send_to_priority_tenant_queue.assert_called_once()

    @patch('services.document_indexing_task_proxy.FeatureService')
    def test_dispatch_with_billing_disabled(self, mock_feature_service):
        """Test _dispatch method when billing is disabled."""
        # Arrange
        mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=False
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_priority_direct_queue = Mock()
        # Act
        proxy._dispatch()
        # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
        proxy._send_to_priority_direct_queue.assert_called_once()

    @patch('services.document_indexing_task_proxy.FeatureService')
    def test_delay_method(self, mock_feature_service):
        """Test delay method integration."""
        # Arrange
        mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan="sandbox"
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_default_tenant_queue = Mock()
        # Act
        proxy.delay()
        # Assert
        # If billing enabled with sandbox plan, should send to default tenant queue
        proxy._send_to_default_tenant_queue.assert_called_once()

    def test_document_task_dataclass(self):
        """Test DocumentTask dataclass."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = ["doc-1", "doc-2"]
        # Act
        task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
        # Assert
        assert task.tenant_id == tenant_id
        assert task.dataset_id == dataset_id
        assert task.document_ids == document_ids

    @patch('services.document_indexing_task_proxy.FeatureService')
    def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
        """Test _dispatch method with empty plan string."""
        # Arrange
        mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=""
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()
        # Act
        proxy._dispatch()
        # Assert
        # An empty plan is not "sandbox", so the priority path is taken.
        proxy._send_to_priority_tenant_queue.assert_called_once()

    @patch('services.document_indexing_task_proxy.FeatureService')
    def test_dispatch_edge_case_none_plan(self, mock_feature_service):
        """Test _dispatch method with None plan."""
        # Arrange
        mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=None
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()
        # Act
        proxy._dispatch()
        # Assert
        proxy._send_to_priority_tenant_queue.assert_called_once()

    def test_initialization_with_empty_document_ids(self):
        """Test initialization with empty document_ids list."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = []
        # Act
        proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
        # Assert
        assert proxy.tenant_id == tenant_id
        assert proxy.dateset_id == dataset_id
        assert proxy.document_ids == document_ids

    def test_initialization_with_single_document_id(self):
        """Test initialization with single document_id."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = ["doc-1"]
        # Act
        proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
        # Assert
        assert proxy.tenant_id == tenant_id
        assert proxy.dateset_id == dataset_id
        assert proxy.document_ids == document_ids

View File

@ -0,0 +1,499 @@
import json
from unittest.mock import Mock, patch
import pytest
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
class RagPipelineTaskProxyTestDataFactory:
    """Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""

    @staticmethod
    def create_mock_features(
        billing_enabled: bool = False,
        plan: str = "sandbox"
    ) -> Mock:
        """Create mock features with billing configuration."""
        # Assemble the billing sub-mock before attaching it to features.
        subscription = Mock()
        subscription.plan = plan
        billing = Mock()
        billing.enabled = billing_enabled
        billing.subscription = subscription
        features = Mock()
        features.billing = billing
        return features

    @staticmethod
    def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
        """Create mock TenantSelfTaskQueue."""
        tenant_queue = Mock(spec=TenantSelfTaskQueue)
        tenant_queue.get_task_key.return_value = "task_key" if has_task_key else None
        tenant_queue.push_tasks = Mock()
        tenant_queue.set_task_waiting_time = Mock()
        return tenant_queue

    @staticmethod
    def create_rag_pipeline_invoke_entity(
        pipeline_id: str = "pipeline-123",
        user_id: str = "user-456",
        tenant_id: str = "tenant-789",
        workflow_id: str = "workflow-101",
        streaming: bool = True,
        workflow_execution_id: str | None = None,
        workflow_thread_pool_id: str | None = None
    ) -> RagPipelineInvokeEntity:
        """Create RagPipelineInvokeEntity instance for testing."""
        return RagPipelineInvokeEntity(
            pipeline_id=pipeline_id,
            application_generate_entity={"key": "value"},
            user_id=user_id,
            tenant_id=tenant_id,
            workflow_id=workflow_id,
            streaming=streaming,
            workflow_execution_id=workflow_execution_id,
            workflow_thread_pool_id=workflow_thread_pool_id
        )

    @staticmethod
    def create_rag_pipeline_task_proxy(
        dataset_tenant_id: str = "tenant-123",
        user_id: str = "user-456",
        rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None
    ) -> RagPipelineTaskProxy:
        """Create RagPipelineTaskProxy instance for testing."""
        # Fall back to a single default invoke entity when none are given.
        entities = (
            [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
            if rag_pipeline_invoke_entities is None
            else rag_pipeline_invoke_entities
        )
        return RagPipelineTaskProxy(dataset_tenant_id, user_id, entities)

    @staticmethod
    def create_mock_upload_file(file_id: str = "file-123") -> Mock:
        """Create mock upload file."""
        upload_file = Mock()
        upload_file.id = file_id
        return upload_file
class TestRagPipelineTaskProxy:
    """Test cases for RagPipelineTaskProxy class.

    Celery tasks, FileService, db, and FeatureService are patched at the
    proxy module path, so routing and upload behavior are verified purely
    via mock call inspection.
    """

    def test_initialization(self):
        """Test RagPipelineTaskProxy initialization."""
        # Arrange
        dataset_tenant_id = "tenant-123"
        user_id = "user-456"
        rag_pipeline_invoke_entities = [
            RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()
        ]
        # Act
        proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
        # Assert
        assert proxy.dataset_tenant_id == dataset_tenant_id
        assert proxy.user_id == user_id
        assert proxy.rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
        assert isinstance(proxy.tenant_self_pipeline_task_queue, TenantSelfTaskQueue)
        assert proxy.tenant_self_pipeline_task_queue.tenant_id == dataset_tenant_id
        assert proxy.tenant_self_pipeline_task_queue.unique_key == "pipeline"

    def test_initialization_with_empty_entities(self):
        """Test initialization with empty rag_pipeline_invoke_entities."""
        # Arrange
        dataset_tenant_id = "tenant-123"
        user_id = "user-456"
        rag_pipeline_invoke_entities = []
        # Act
        proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
        # Assert
        assert proxy.dataset_tenant_id == dataset_tenant_id
        assert proxy.user_id == user_id
        assert proxy.rag_pipeline_invoke_entities == []

    def test_initialization_with_multiple_entities(self):
        """Test initialization with multiple rag_pipeline_invoke_entities."""
        # Arrange
        dataset_tenant_id = "tenant-123"
        user_id = "user-456"
        rag_pipeline_invoke_entities = [
            RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
            RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
            RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3")
        ]
        # Act
        proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
        # Assert
        assert len(proxy.rag_pipeline_invoke_entities) == 3
        assert proxy.rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
        assert proxy.rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
        assert proxy.rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"

    @patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
    def test_features_property(self, mock_feature_service):
        """Test cached_property features."""
        # Arrange
        mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
        mock_feature_service.get_features.return_value = mock_features
        proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
        # Act
        features1 = proxy.features
        features2 = proxy.features  # Second call should use cached property
        # Assert
        assert features1 == mock_features
        assert features2 == mock_features
        assert features1 is features2  # Should be the same instance due to caching
        mock_feature_service.get_features.assert_called_once_with("tenant-123")

    # NOTE: stacked @patch decorators apply bottom-up, so the innermost
    # patch (db) binds to the first mock parameter.
    @patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
    @patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
    def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
        """Test _upload_invoke_entities method."""
        # Arrange
        proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
        mock_file_service = Mock()
        mock_file_service_class.return_value = mock_file_service
        mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
        mock_file_service.upload_text.return_value = mock_upload_file
        # Act
        result = proxy._upload_invoke_entities()
        # Assert
        assert result == "file-123"
        mock_file_service_class.assert_called_once_with(mock_db.engine)
        # Verify upload_text was called with correct parameters
        mock_file_service.upload_text.assert_called_once()
        call_args = mock_file_service.upload_text.call_args
        json_text, name, user_id, tenant_id = call_args[0]
        assert name == "rag_pipeline_invoke_entities.json"
        assert user_id == "user-456"
        assert tenant_id == "tenant-123"
        # Verify JSON content
        parsed_json = json.loads(json_text)
        assert len(parsed_json) == 1
        assert parsed_json[0]["pipeline_id"] == "pipeline-123"

    @patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
    @patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
    def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
        """Test _upload_invoke_entities method with multiple entities."""
        # Arrange
        entities = [
            RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
            RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2")
        ]
        proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
        mock_file_service = Mock()
        mock_file_service_class.return_value = mock_file_service
        mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
        mock_file_service.upload_text.return_value = mock_upload_file
        # Act
        result = proxy._upload_invoke_entities()
        # Assert
        assert result == "file-456"
        # Verify JSON content contains both entities
        call_args = mock_file_service.upload_text.call_args
        json_text = call_args[0][0]
        parsed_json = json.loads(json_text)
        assert len(parsed_json) == 2
        assert parsed_json[0]["pipeline_id"] == "pipeline-1"
        assert parsed_json[1]["pipeline_id"] == "pipeline-2"

    @patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
    def test_send_to_direct_queue(self, mock_task):
        """Test _send_to_direct_queue method."""
        # Arrange
        proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
        proxy.tenant_self_pipeline_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
        upload_file_id = "file-123"
        mock_task.delay = Mock()
        # Act
        proxy._send_to_direct_queue(upload_file_id, mock_task)
        # If sent to direct queue, tenant_self_pipeline_task_queue should not be called
        proxy.tenant_self_pipeline_task_queue.push_tasks.assert_not_called()
        # Celery should be called directly
        mock_task.delay.assert_called_once_with(
            rag_pipeline_invoke_entities_file_id=upload_file_id,
            tenant_id="tenant-123"
        )

    @patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
    def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
        """Test _send_to_tenant_queue when task key exists."""
        # Arrange
        proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
        proxy.tenant_self_pipeline_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=True
        )
        upload_file_id = "file-123"
        mock_task.delay = Mock()
        # Act
        proxy._send_to_tenant_queue(upload_file_id, mock_task)
        # If task key exists, should push tasks to the queue
        proxy.tenant_self_pipeline_task_queue.push_tasks.assert_called_once_with([upload_file_id])
        # Celery should not be called directly
        mock_task.delay.assert_not_called()

    @patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
    def test_send_to_tenant_queue_without_task_key(self, mock_task):
        """Test _send_to_tenant_queue when no task key exists."""
        # Arrange
        proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
        proxy.tenant_self_pipeline_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=False
        )
        upload_file_id = "file-123"
        mock_task.delay = Mock()
        # Act
        proxy._send_to_tenant_queue(upload_file_id, mock_task)
        # If no task key, should set task waiting time key first
        proxy.tenant_self_pipeline_task_queue.set_task_waiting_time.assert_called_once()
        mock_task.delay.assert_called_once_with(
            rag_pipeline_invoke_entities_file_id=upload_file_id,
            tenant_id="tenant-123"
        )
        # The first task should be sent to celery directly, so push tasks should not be called
        proxy.tenant_self_pipeline_task_queue.push_tasks.assert_not_called()

    @patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
    def test_send_to_default_tenant_queue(self, mock_task):
        """Test _send_to_default_tenant_queue method."""
        # Arrange
        proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
        proxy._send_to_tenant_queue = Mock()
        upload_file_id = "file-123"
        # Act
        proxy._send_to_default_tenant_queue(upload_file_id)
        # Assert
        proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)

    @patch('services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task')
    def test_send_to_priority_tenant_queue(self, mock_task):
        """Test _send_to_priority_tenant_queue method."""
        # Arrange
        proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
        proxy._send_to_tenant_queue = Mock()
        upload_file_id = "file-123"
        # Act
        proxy._send_to_priority_tenant_queue(upload_file_id)
        # Assert
        proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch('services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task')
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_direct_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_direct_queue(upload_file_id)
# Assert
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="sandbox"
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_billing_enabled_non_sandbox_plan(
self, mock_db, mock_file_service_class, mock_feature_service
):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="team"
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=False
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
"""Test _dispatch method when upload_file_id is empty."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = Mock()
mock_upload_file.id = "" # Empty file ID
mock_file_service.upload_text.return_value = mock_upload_file
# Act & Assert
with pytest.raises(ValueError, match="upload_file_id is empty"):
proxy._dispatch()
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=""
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=None
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="sandbox"
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._dispatch = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy.delay()
# Assert
proxy._dispatch.assert_called_once()
@patch('services.rag_pipeline.rag_pipeline_task_proxy.logger')
def test_delay_method_with_empty_entities(self, mock_logger):
"""Test delay method with empty rag_pipeline_invoke_entities."""
# Arrange
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
# Act
proxy.delay()
# Assert
mock_logger.warning.assert_called_once_with(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
"tenant-123",
"user-456"
)