Merge remote-tracking branch 'origin/main' into feat/trigger

# Conflicts:
#	api/docker/entrypoint.sh
#	api/uv.lock
#	dev/start-worker
#	docker/.env.example
#	docker/docker-compose.yaml
#	web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx
#	web/app/components/base/date-and-time-picker/date-picker/index.tsx
#	web/app/components/base/date-and-time-picker/types.ts
This commit is contained in:
Harry
2025-11-11 12:42:01 +08:00
97 changed files with 6372 additions and 1289 deletions

View File

@ -0,0 +1,301 @@
"""
Unit tests for TenantIsolatedTaskQueue.
These tests verify the Redis-based task queue functionality for tenant-specific
task management with proper serialization and deserialization.
"""
import json
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from pydantic import ValidationError
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
class TestTaskWrapper:
"""Test cases for TaskWrapper serialization/deserialization."""
def test_serialize_simple_data(self):
"""Test serialization of simple data types."""
data = {"key": "value", "number": 42, "list": [1, 2, 3]}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
assert isinstance(serialized, str)
# Verify it's valid JSON
parsed = json.loads(serialized)
assert parsed["data"] == data
def test_serialize_complex_data(self):
"""Test serialization of complex nested data."""
data = {
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
"unicode": "测试中文",
"special_chars": "!@#$%^&*()",
}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
parsed = json.loads(serialized)
assert parsed["data"] == data
def test_deserialize_valid_data(self):
"""Test deserialization of valid JSON data."""
original_data = {"key": "value", "number": 42}
# Serialize using TaskWrapper to get the correct format
wrapper = TaskWrapper(data=original_data)
serialized = wrapper.serialize()
wrapper = TaskWrapper.deserialize(serialized)
assert wrapper.data == original_data
def test_deserialize_invalid_json(self):
"""Test deserialization handles invalid JSON gracefully."""
invalid_json = "{invalid json}"
# Pydantic will raise ValidationError for invalid JSON
with pytest.raises(ValidationError):
TaskWrapper.deserialize(invalid_json)
def test_serialize_ensure_ascii_false(self):
"""Test that serialization preserves Unicode characters."""
data = {"chinese": "中文测试", "emoji": "🚀"}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
assert "中文测试" in serialized
assert "🚀" in serialized
class TestTenantIsolatedTaskQueue:
"""Test cases for TenantIsolatedTaskQueue functionality."""
@pytest.fixture
def mock_redis_client(self):
"""Mock Redis client for testing."""
mock_redis = MagicMock()
return mock_redis
@pytest.fixture
def sample_queue(self, mock_redis_client):
"""Create a sample TenantIsolatedTaskQueue instance."""
return TenantIsolatedTaskQueue("tenant-123", "test-key")
def test_initialization(self, sample_queue):
"""Test queue initialization with correct key generation."""
assert sample_queue._tenant_id == "tenant-123"
assert sample_queue._unique_key == "test-key"
assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
assert sample_queue._task_key == "tenant_test-key_task:tenant-123"
@patch("core.rag.pipeline.queue.redis_client")
def test_get_task_key_exists(self, mock_redis, sample_queue):
"""Test getting task key when it exists."""
mock_redis.get.return_value = "1"
result = sample_queue.get_task_key()
assert result == "1"
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_get_task_key_not_exists(self, mock_redis, sample_queue):
"""Test getting task key when it doesn't exist."""
mock_redis.get.return_value = None
result = sample_queue.get_task_key()
assert result is None
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with default TTL."""
sample_queue.set_task_waiting_time()
mock_redis.setex.assert_called_once_with(
"tenant_test-key_task:tenant-123",
3600, # DEFAULT_TASK_TTL
1,
)
@patch("core.rag.pipeline.queue.redis_client")
def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with custom TTL."""
custom_ttl = 1800
sample_queue.set_task_waiting_time(custom_ttl)
mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)
@patch("core.rag.pipeline.queue.redis_client")
def test_delete_task_key(self, mock_redis, sample_queue):
"""Test deleting task key."""
sample_queue.delete_task_key()
mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_string_list(self, mock_redis, sample_queue):
"""Test pushing string tasks directly."""
tasks = ["task1", "task2", "task3"]
sample_queue.push_tasks(tasks)
mock_redis.lpush.assert_called_once_with(
"tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
)
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
"""Test pushing mixed string and object tasks."""
tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]
sample_queue.push_tasks(tasks)
# Verify lpush was called
mock_redis.lpush.assert_called_once()
call_args = mock_redis.lpush.call_args
# Check queue name
assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
# Check serialized tasks
serialized_tasks = call_args[0][1:]
assert len(serialized_tasks) == 3
assert serialized_tasks[0] == "string_task"
assert serialized_tasks[2] == "another_string"
# Check object task is serialized as TaskWrapper JSON (without prefix)
# It should be a valid JSON string that can be deserialized by TaskWrapper
wrapper = TaskWrapper.deserialize(serialized_tasks[1])
assert wrapper.data == {"object_task": "data", "id": 123}
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_empty_list(self, mock_redis, sample_queue):
"""Test pushing empty task list."""
sample_queue.push_tasks([])
mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_default_count(self, mock_redis, sample_queue):
"""Test pulling tasks with default count (1)."""
mock_redis.rpop.side_effect = ["task1", None]
result = sample_queue.pull_tasks()
assert result == ["task1"]
assert mock_redis.rpop.call_count == 1
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
"""Test pulling tasks with custom count."""
# First test: pull 3 tasks
mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2", "task3"]
assert mock_redis.rpop.call_count == 3
# Reset mock for second test
mock_redis.reset_mock()
mock_redis.rpop.side_effect = ["task1", "task2", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2"]
assert mock_redis.rpop.call_count == 3
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
"""Test pulling tasks with zero count returns empty list."""
result = sample_queue.pull_tasks(0)
assert result == []
mock_redis.rpop.assert_not_called()
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
"""Test pulling tasks with negative count returns empty list."""
result = sample_queue.pull_tasks(-1)
assert result == []
mock_redis.rpop.assert_not_called()
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
"""Test pulling tasks that include wrapped objects."""
# Create a wrapped task
task_data = {"task_id": 123, "data": "test"}
wrapper = TaskWrapper(data=task_data)
wrapped_task = wrapper.serialize()
mock_redis.rpop.side_effect = [
"string_task",
wrapped_task.encode("utf-8"), # Simulate bytes from Redis
None,
]
result = sample_queue.pull_tasks(2)
assert len(result) == 2
assert result[0] == "string_task"
assert result[1] == {"task_id": 123, "data": "test"}
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
"""Test pulling tasks with invalid JSON falls back to string."""
# Invalid JSON string that cannot be deserialized
invalid_json = "invalid json data"
mock_redis.rpop.side_effect = [invalid_json, None]
result = sample_queue.pull_tasks(1)
assert result == [invalid_json]
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
"""Test pulling tasks handles bytes from Redis correctly."""
mock_redis.rpop.side_effect = [
b"task1", # bytes
"task2", # string
None,
]
result = sample_queue.pull_tasks(2)
assert result == ["task1", "task2"]
@patch("core.rag.pipeline.queue.redis_client")
def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
"""Test complex object serialization and deserialization roundtrip."""
complex_task = {
"id": uuid4().hex,
"data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
"metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
}
# Push the complex task
sample_queue.push_tasks([complex_task])
# Verify it was serialized as TaskWrapper JSON
call_args = mock_redis.lpush.call_args
wrapped_task = call_args[0][1]
# Verify it's a valid TaskWrapper JSON (starts with {"data":)
assert wrapped_task.startswith('{"data":')
# Verify it can be deserialized
wrapper = TaskWrapper.deserialize(wrapped_task)
assert wrapper.data == complex_task
# Simulate pulling it back
mock_redis.rpop.return_value = wrapped_task
result = sample_queue.pull_tasks(1)
assert len(result) == 1
assert result[0] == complex_task

View File

@ -111,3 +111,26 @@ class TestVariablePoolGetAndNestedAttribute:
assert segment_false is not None
assert isinstance(segment_false, BooleanSegment)
assert segment_false.value is False
class TestVariablePoolGetNotModifyVariableDictionary:
_NODE_ID = "start"
_VAR_NAME = "name"
def test_convert_to_template_should_not_introduce_extra_keys(self):
pool = VariablePool.empty()
pool.add([self._NODE_ID, self._VAR_NAME], 0)
pool.convert_template("The start.name is {{#start.name#}}")
assert "The start" not in pool.variable_dictionary
def test_get_should_not_modify_variable_dictionary(self):
pool = VariablePool.empty()
pool.get([self._NODE_ID, self._VAR_NAME])
assert len(pool.variable_dictionary) == 1 # only contains `sys` node id
assert "start" not in pool.variable_dictionary
pool = VariablePool.empty()
pool.add([self._NODE_ID, self._VAR_NAME], "Joe")
pool.get([self._NODE_ID, "count"])
start_subdict = pool.variable_dictionary[self._NODE_ID]
assert "count" not in start_subdict

View File

@ -0,0 +1,514 @@
"""
Comprehensive unit tests for Redis broadcast channel implementation.
This test suite covers all aspects of the Redis broadcast channel including:
- Basic functionality and contract compliance
- Error handling and edge cases
- Thread safety and concurrency
- Resource management and cleanup
- Performance and reliability scenarios
"""
import dataclasses
import threading
import time
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import pytest
from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError
from libs.broadcast_channel.redis.channel import (
BroadcastChannel as RedisBroadcastChannel,
)
from libs.broadcast_channel.redis.channel import (
Topic,
_RedisSubscription,
)
class TestBroadcastChannel:
"""Test cases for the main BroadcastChannel class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
"""Create a mock Redis client for testing."""
client = MagicMock()
client.pubsub.return_value = MagicMock()
return client
@pytest.fixture
def broadcast_channel(self, mock_redis_client: MagicMock) -> RedisBroadcastChannel:
"""Create a BroadcastChannel instance with mock Redis client."""
return RedisBroadcastChannel(mock_redis_client)
def test_topic_creation(self, broadcast_channel: RedisBroadcastChannel, mock_redis_client: MagicMock):
"""Test that topic() method returns a Topic instance with correct parameters."""
topic_name = "test-topic"
topic = broadcast_channel.topic(topic_name)
assert isinstance(topic, Topic)
assert topic._client == mock_redis_client
assert topic._topic == topic_name
def test_topic_isolation(self, broadcast_channel: RedisBroadcastChannel):
"""Test that different topic names create isolated Topic instances."""
topic1 = broadcast_channel.topic("topic1")
topic2 = broadcast_channel.topic("topic2")
assert topic1 is not topic2
assert topic1._topic == "topic1"
assert topic2._topic == "topic2"
class TestTopic:
"""Test cases for the Topic class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
"""Create a mock Redis client for testing."""
client = MagicMock()
client.pubsub.return_value = MagicMock()
return client
@pytest.fixture
def topic(self, mock_redis_client: MagicMock) -> Topic:
"""Create a Topic instance for testing."""
return Topic(mock_redis_client, "test-topic")
def test_as_producer_returns_self(self, topic: Topic):
"""Test that as_producer() returns self as Producer interface."""
producer = topic.as_producer()
assert producer is topic
# Producer is a Protocol, check duck typing instead
assert hasattr(producer, "publish")
def test_as_subscriber_returns_self(self, topic: Topic):
"""Test that as_subscriber() returns self as Subscriber interface."""
subscriber = topic.as_subscriber()
assert subscriber is topic
# Subscriber is a Protocol, check duck typing instead
assert hasattr(subscriber, "subscribe")
def test_publish_calls_redis_publish(self, topic: Topic, mock_redis_client: MagicMock):
"""Test that publish() calls Redis PUBLISH with correct parameters."""
payload = b"test message"
topic.publish(payload)
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
@dataclasses.dataclass(frozen=True)
class SubscriptionTestCase:
"""Test case data for subscription tests."""
name: str
buffer_size: int
payload: bytes
expected_messages: list[bytes]
should_drop: bool = False
description: str = ""
class TestRedisSubscription:
"""Test cases for the _RedisSubscription class."""
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
pubsub = MagicMock()
pubsub.subscribe = MagicMock()
pubsub.unsubscribe = MagicMock()
pubsub.close = MagicMock()
pubsub.get_message = MagicMock()
return pubsub
@pytest.fixture
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
"""Create a _RedisSubscription instance for testing."""
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic="test-topic",
)
yield subscription
subscription.close()
@pytest.fixture
def started_subscription(self, subscription: _RedisSubscription) -> _RedisSubscription:
"""Create a subscription that has been started."""
subscription._start_if_needed()
return subscription
# ==================== Lifecycle Tests ====================
def test_subscription_initialization(self, mock_pubsub: MagicMock):
"""Test that subscription is properly initialized."""
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic="test-topic",
)
assert subscription._pubsub is mock_pubsub
assert subscription._topic == "test-topic"
assert not subscription._closed.is_set()
assert subscription._dropped_count == 0
assert subscription._listener_thread is None
assert not subscription._started
def test_start_if_needed_first_call(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that _start_if_needed() properly starts subscription on first call."""
subscription._start_if_needed()
mock_pubsub.subscribe.assert_called_once_with("test-topic")
assert subscription._started is True
assert subscription._listener_thread is not None
def test_start_if_needed_subsequent_calls(self, started_subscription: _RedisSubscription):
"""Test that _start_if_needed() doesn't start subscription on subsequent calls."""
original_thread = started_subscription._listener_thread
started_subscription._start_if_needed()
# Should not create new thread or generator
assert started_subscription._listener_thread is original_thread
def test_start_if_needed_when_closed(self, subscription: _RedisSubscription):
"""Test that _start_if_needed() raises error when subscription is closed."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
subscription._start_if_needed()
def test_start_if_needed_when_cleaned_up(self, subscription: _RedisSubscription):
"""Test that _start_if_needed() raises error when pubsub is None."""
subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
subscription._start_if_needed()
def test_context_manager_usage(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that subscription works as context manager."""
with subscription as sub:
assert sub is subscription
assert subscription._started is True
mock_pubsub.subscribe.assert_called_once_with("test-topic")
def test_close_idempotent(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that close() is idempotent and can be called multiple times."""
subscription._start_if_needed()
# Close multiple times
subscription.close()
subscription.close()
subscription.close()
# Should only cleanup once
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
mock_pubsub.close.assert_called_once()
assert subscription._pubsub is None
assert subscription._closed.is_set()
def test_close_cleanup(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that close() properly cleans up all resources."""
subscription._start_if_needed()
thread = subscription._listener_thread
subscription.close()
# Verify cleanup
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
mock_pubsub.close.assert_called_once()
assert subscription._pubsub is None
assert subscription._listener_thread is None
# Wait for thread to finish (with timeout)
if thread and thread.is_alive():
thread.join(timeout=1.0)
assert not thread.is_alive()
# ==================== Message Processing Tests ====================
def test_message_iterator_with_messages(self, started_subscription: _RedisSubscription):
"""Test message iterator behavior with messages in queue."""
test_messages = [b"msg1", b"msg2", b"msg3"]
# Add messages to queue
for msg in test_messages:
started_subscription._queue.put_nowait(msg)
# Iterate through messages
iterator = iter(started_subscription)
received_messages = []
for msg in iterator:
received_messages.append(msg)
if len(received_messages) >= len(test_messages):
break
assert received_messages == test_messages
def test_message_iterator_when_closed(self, subscription: _RedisSubscription):
"""Test that iterator raises error when subscription is closed."""
subscription.close()
with pytest.raises(BroadcastChannelError, match="The Redis subscription is closed"):
iter(subscription)
# ==================== Message Enqueue Tests ====================
def test_enqueue_message_success(self, started_subscription: _RedisSubscription):
"""Test successful message enqueue."""
payload = b"test message"
started_subscription._enqueue_message(payload)
assert started_subscription._queue.qsize() == 1
assert started_subscription._queue.get_nowait() == payload
def test_enqueue_message_when_closed(self, subscription: _RedisSubscription):
"""Test message enqueue when subscription is closed."""
subscription.close()
payload = b"test message"
# Should not raise exception, but should not enqueue
subscription._enqueue_message(payload)
assert subscription._queue.empty()
def test_enqueue_message_with_full_queue(self, started_subscription: _RedisSubscription):
"""Test message enqueue with full queue (dropping behavior)."""
# Fill the queue
for i in range(started_subscription._queue.maxsize):
started_subscription._queue.put_nowait(f"old_msg_{i}".encode())
# Try to enqueue new message (should drop oldest)
new_message = b"new_message"
started_subscription._enqueue_message(new_message)
# Should have dropped one message and added new one
assert started_subscription._dropped_count == 1
# New message should be in queue
messages = []
while not started_subscription._queue.empty():
messages.append(started_subscription._queue.get_nowait())
assert new_message in messages
# ==================== Listener Thread Tests ====================
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
def test_listener_thread_normal_operation(
self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock
):
"""Test listener thread normal operation."""
# Mock message from Redis
mock_message = {"type": "message", "channel": "test-topic", "data": b"test payload"}
mock_pubsub.get_message.return_value = mock_message
# Start listener
subscription._start_if_needed()
# Wait a bit for processing
time.sleep(0.1)
# Verify message was processed
assert not subscription._queue.empty()
assert subscription._queue.get_nowait() == b"test payload"
def test_listener_thread_ignores_subscribe_messages(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread ignores subscribe/unsubscribe messages."""
mock_message = {"type": "subscribe", "channel": "test-topic", "data": 1}
mock_pubsub.get_message.return_value = mock_message
subscription._start_if_needed()
time.sleep(0.1)
# Should not enqueue subscribe messages
assert subscription._queue.empty()
def test_listener_thread_ignores_wrong_channel(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread ignores messages from wrong channels."""
mock_message = {"type": "message", "channel": "wrong-topic", "data": b"test payload"}
mock_pubsub.get_message.return_value = mock_message
subscription._start_if_needed()
time.sleep(0.1)
# Should not enqueue messages from wrong channels
assert subscription._queue.empty()
def test_listener_thread_handles_redis_exceptions(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread handles Redis exceptions gracefully."""
mock_pubsub.get_message.side_effect = Exception("Redis error")
subscription._start_if_needed()
# Wait for thread to handle exception
time.sleep(0.2)
# Thread should still be alive but not processing
assert subscription._listener_thread is not None
assert not subscription._listener_thread.is_alive()
def test_listener_thread_stops_when_closed(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread stops when subscription is closed."""
subscription._start_if_needed()
thread = subscription._listener_thread
# Close subscription
subscription.close()
# Wait for thread to finish
if thread is not None and thread.is_alive():
thread.join(timeout=1.0)
assert thread is None or not thread.is_alive()
# ==================== Table-driven Tests ====================
@pytest.mark.parametrize(
"test_case",
[
SubscriptionTestCase(
name="basic_message",
buffer_size=5,
payload=b"hello world",
expected_messages=[b"hello world"],
description="Basic message publishing and receiving",
),
SubscriptionTestCase(
name="empty_message",
buffer_size=5,
payload=b"",
expected_messages=[b""],
description="Empty message handling",
),
SubscriptionTestCase(
name="large_message",
buffer_size=5,
payload=b"x" * 10000,
expected_messages=[b"x" * 10000],
description="Large message handling",
),
SubscriptionTestCase(
name="unicode_message",
buffer_size=5,
payload="你好世界".encode(),
expected_messages=["你好世界".encode()],
description="Unicode message handling",
),
],
)
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
"""Test various subscription scenarios using table-driven approach."""
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic="test-topic",
)
# Simulate receiving message
mock_message = {"type": "message", "channel": "test-topic", "data": test_case.payload}
mock_pubsub.get_message.return_value = mock_message
try:
with subscription:
# Wait for message processing
time.sleep(0.1)
# Collect received messages
received = []
for msg in subscription:
received.append(msg)
if len(received) >= len(test_case.expected_messages):
break
assert received == test_case.expected_messages, f"Failed: {test_case.description}"
finally:
subscription.close()
def test_concurrent_close_and_enqueue(self, started_subscription: _RedisSubscription):
"""Test concurrent close and enqueue operations."""
errors = []
def close_subscription():
try:
time.sleep(0.05) # Small delay
started_subscription.close()
except Exception as e:
errors.append(e)
def enqueue_messages():
try:
for i in range(50):
started_subscription._enqueue_message(f"msg_{i}".encode())
time.sleep(0.001)
except Exception as e:
errors.append(e)
# Start threads
close_thread = threading.Thread(target=close_subscription)
enqueue_thread = threading.Thread(target=enqueue_messages)
close_thread.start()
enqueue_thread.start()
# Wait for completion
close_thread.join(timeout=2.0)
enqueue_thread.join(timeout=2.0)
# Should not have any errors (operations should be safe)
assert len(errors) == 0
# ==================== Error Handling Tests ====================
def test_iterator_after_close(self, subscription: _RedisSubscription):
"""Test iterator behavior after close."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
iter(subscription)
def test_start_after_close(self, subscription: _RedisSubscription):
"""Test start attempts after close."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
subscription._start_if_needed()
def test_pubsub_none_operations(self, subscription: _RedisSubscription):
"""Test operations when pubsub is None."""
subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
subscription._start_if_needed()
# Close should still work
subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock):
"""Test various channel name formats."""
channel_names = [
"simple",
"with-dashes",
"with_underscores",
"with.numbers",
"WITH.UPPERCASE",
"mixed-CASE_name",
"very.long.channel.name.with.multiple.parts",
]
for channel_name in channel_names:
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic=channel_name,
)
subscription._start_if_needed()
mock_pubsub.subscribe.assert_called_with(channel_name)
subscription.close()
def test_received_on_closed_subscription(self, subscription: _RedisSubscription):
subscription.close()
with pytest.raises(SubscriptionClosedError):
subscription.receive()

View File

@ -0,0 +1,317 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_document_task_proxy(
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
) -> DocumentIndexingTaskProxy:
"""Create DocumentIndexingTaskProxy instance for testing."""
if document_ids is None:
document_ids = ["doc-1", "doc-2", "doc-3"]
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
class TestDocumentIndexingTaskProxy:
"""Test cases for DocumentIndexingTaskProxy class."""
def test_initialization(self):
"""Test DocumentIndexingTaskProxy initialization."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2", "doc-3"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
@patch("services.document_indexing_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(mock_task)
# Assert
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
assert len(pushed_tasks) == 1
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_direct_queue = Mock()
# Act
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# If billing enabled with non sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
# Act
proxy._dispatch()
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy.delay()
# Assert
# If billing enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once()
def test_document_task_dataclass(self):
"""Test DocumentTask dataclass."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2"]
# Act
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
# Assert
assert task.tenant_id == tenant_id
assert task.dataset_id == dataset_id
assert task.document_ids == document_ids
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
def test_initialization_with_empty_document_ids(self):
"""Test initialization with empty document_ids list."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = []
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
def test_initialization_with_single_document_id(self):
"""Test initialization with single document_id."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids

View File

@ -0,0 +1,483 @@
import json
from unittest.mock import Mock, patch
import pytest
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
class RagPipelineTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_rag_pipeline_invoke_entity(
pipeline_id: str = "pipeline-123",
user_id: str = "user-456",
tenant_id: str = "tenant-789",
workflow_id: str = "workflow-101",
streaming: bool = True,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None,
) -> RagPipelineInvokeEntity:
"""Create RagPipelineInvokeEntity instance for testing."""
return RagPipelineInvokeEntity(
pipeline_id=pipeline_id,
application_generate_entity={"key": "value"},
user_id=user_id,
tenant_id=tenant_id,
workflow_id=workflow_id,
streaming=streaming,
workflow_execution_id=workflow_execution_id,
workflow_thread_pool_id=workflow_thread_pool_id,
)
@staticmethod
def create_rag_pipeline_task_proxy(
dataset_tenant_id: str = "tenant-123",
user_id: str = "user-456",
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
) -> RagPipelineTaskProxy:
"""Create RagPipelineTaskProxy instance for testing."""
if rag_pipeline_invoke_entities is None:
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
@staticmethod
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
"""Create mock upload file."""
upload_file = Mock()
upload_file.id = file_id
return upload_file
class TestRagPipelineTaskProxy:
"""Test cases for RagPipelineTaskProxy class."""
def test_initialization(self):
"""Test RagPipelineTaskProxy initialization."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
def test_initialization_with_empty_entities(self):
"""Test initialization with empty rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = []
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == []
def test_initialization_with_multiple_entities(self):
"""Test initialization with multiple rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert len(proxy._rag_pipeline_invoke_entities) == 3
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-123"
mock_file_service_class.assert_called_once_with(mock_db.engine)
# Verify upload_text was called with correct parameters
mock_file_service.upload_text.assert_called_once()
call_args = mock_file_service.upload_text.call_args
json_text, name, user_id, tenant_id = call_args[0]
assert name == "rag_pipeline_invoke_entities.json"
assert user_id == "user-456"
assert tenant_id == "tenant-123"
# Verify JSON content
parsed_json = json.loads(json_text)
assert len(parsed_json) == 1
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method with multiple entities."""
# Arrange
entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
]
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-456"
# Verify JSON content contains both entities
call_args = mock_file_service.upload_text.call_args
json_text = call_args[0][0]
parsed_json = json.loads(json_text)
assert len(parsed_json) == 2
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(upload_file_id, mock_task)
# If sent to direct queue, tenant_isolated_task_queue should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
# Celery should be called directly
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If task key exists, should push tasks to the queue
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
# Celery should not be called directly
mock_task.delay.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If no task key, should set task waiting time key first
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
# The first task should be sent to celery directly, so push tasks should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_default_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_direct_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_direct_queue(upload_file_id)
# Assert
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_non_sandbox_plan(
self, mock_db, mock_file_service_class, mock_feature_service
):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
"""Test _dispatch method when upload_file_id is empty."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = Mock()
mock_upload_file.id = "" # Empty file ID
mock_file_service.upload_text.return_value = mock_upload_file
# Act & Assert
with pytest.raises(ValueError, match="upload_file_id is empty"):
proxy._dispatch()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._dispatch = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy.delay()
# Assert
proxy._dispatch.assert_called_once()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
def test_delay_method_with_empty_entities(self, mock_logger):
"""Test delay method with empty rag_pipeline_invoke_entities."""
# Arrange
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
# Act
proxy.delay()
# Assert
mock_logger.warning.assert_called_once_with(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
)