fix(api): fix performance issue in ShardedRedisBroadcastChannel

This commit is contained in:
QuantumGhost
2026-02-05 13:28:39 +08:00
parent f614153f30
commit f21782a9a3
5 changed files with 3381 additions and 13 deletions

View File

@ -8,6 +8,7 @@ from typing import Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis.client import PubSub
from redis import Redis, RedisCluster
_logger = logging.getLogger(__name__)
@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription):
def __init__(
self,
client: Redis | RedisCluster,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._client = client
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()

View File

@ -42,6 +42,7 @@ class Topic:
def subscribe(self) -> Subscription:
return _RedisSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from redis import Redis, RedisCluster
from ._subscription import RedisSubscriptionBase
@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel:
def __init__(
self,
redis_client: Redis,
redis_client: Redis | RedisCluster,
):
self._client = redis_client
@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel:
class ShardedTopic:
def __init__(self, redis_client: Redis, topic: str):
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
@ -40,6 +40,7 @@ class ShardedTopic:
def subscribe(self) -> Subscription:
return _RedisShardedSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
)
@ -68,7 +69,15 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
#
# Since we have already filtered at the caller's site, we can safely set
# `ignore_subscribe_messages=False`.
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
if isinstance(self._client, RedisCluster):
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message(
ignore_subscribe_messages=False,
timeout=0.1,
target_node=node,
)
else:
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
def _get_message_type(self) -> str:
return "smessage"