mirror of
https://github.com/langgenius/dify.git
synced 2026-03-25 00:07:56 +08:00
add a test to ensure ShardedBroadcastChannel Works as expected (vibe-kanban a25b7a4c)
I have fixed a bug in `_RedisShardedSubscription` in the latest commit. I need you to write an integration test case for this bug. You should write the test in TDD way. You are free to change the git head.
This commit is contained in:
@ -16,6 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from redis.cluster import RedisCluster
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||
@ -332,3 +333,95 @@ class TestShardedRedisBroadcastChannelIntegration:
|
||||
# Verify subscriptions are cleaned up
|
||||
topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name)
|
||||
assert topic_subscribers_after == 0
|
||||
|
||||
|
||||
class TestShardedRedisBroadcastChannelClusterIntegration:
|
||||
"""Integration tests for sharded pub/sub with RedisCluster client."""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_cluster_container(self) -> Iterator[RedisContainer]:
|
||||
"""Create a Redis 7 container with cluster mode enabled."""
|
||||
command = (
|
||||
"redis-server --port 6379 "
|
||||
"--cluster-enabled yes "
|
||||
"--cluster-config-file nodes.conf "
|
||||
"--cluster-node-timeout 5000 "
|
||||
"--appendonly no "
|
||||
"--protected-mode no"
|
||||
)
|
||||
with RedisContainer(image="redis:7-alpine").with_command(command) as container:
|
||||
yield container
|
||||
|
||||
@classmethod
|
||||
def _get_test_topic_name(cls) -> str:
|
||||
return f"test_sharded_cluster_topic_{uuid.uuid4()}"
|
||||
|
||||
@staticmethod
|
||||
def _ensure_single_node_cluster(host: str, port: int) -> None:
|
||||
client = redis.Redis(host=host, port=port, decode_responses=False)
|
||||
client.config_set("cluster-announce-ip", host)
|
||||
client.config_set("cluster-announce-port", port)
|
||||
slots = client.execute_command("CLUSTER", "SLOTS")
|
||||
if not slots:
|
||||
client.execute_command("CLUSTER", "ADDSLOTSRANGE", 0, 16383)
|
||||
|
||||
deadline = time.time() + 5.0
|
||||
while time.time() < deadline:
|
||||
info = client.execute_command("CLUSTER", "INFO")
|
||||
info_text = info.decode("utf-8") if isinstance(info, (bytes, bytearray)) else str(info)
|
||||
if "cluster_state:ok" in info_text:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Redis cluster did not become ready in time")
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_cluster_client(self, redis_cluster_container: RedisContainer) -> RedisCluster:
|
||||
host = redis_cluster_container.get_container_host_ip()
|
||||
port = int(redis_cluster_container.get_exposed_port(6379))
|
||||
self._ensure_single_node_cluster(host, port)
|
||||
return RedisCluster(host=host, port=port, decode_responses=False)
|
||||
|
||||
@pytest.fixture
|
||||
def broadcast_channel(self, redis_cluster_client: RedisCluster) -> BroadcastChannel:
|
||||
return ShardedRedisBroadcastChannel(redis_cluster_client)
|
||||
|
||||
def test_cluster_sharded_pubsub_delivers_message(self, broadcast_channel: BroadcastChannel):
|
||||
"""Ensure sharded subscription receives messages when using RedisCluster client."""
|
||||
topic_name = self._get_test_topic_name()
|
||||
message = b"cluster sharded message"
|
||||
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
producer = topic.as_producer()
|
||||
subscription = topic.subscribe()
|
||||
ready_event = threading.Event()
|
||||
|
||||
def consumer_thread() -> list[bytes]:
|
||||
received = []
|
||||
try:
|
||||
_ = subscription.receive(0.01)
|
||||
except SubscriptionClosedError:
|
||||
return received
|
||||
ready_event.set()
|
||||
deadline = time.time() + 5.0
|
||||
while time.time() < deadline:
|
||||
msg = subscription.receive(timeout=0.1)
|
||||
if msg is None:
|
||||
continue
|
||||
received.append(msg)
|
||||
break
|
||||
subscription.close()
|
||||
return received
|
||||
|
||||
def producer_thread():
|
||||
if not ready_event.wait(timeout=2.0):
|
||||
pytest.fail("subscriber did not become ready before publish")
|
||||
producer.publish(message)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
consumer_future = executor.submit(consumer_thread)
|
||||
producer_future = executor.submit(producer_thread)
|
||||
|
||||
producer_future.result(timeout=5.0)
|
||||
received_messages = consumer_future.result(timeout=5.0)
|
||||
|
||||
assert received_messages == [message]
|
||||
|
||||
Reference in New Issue
Block a user