mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
fix(api): fix performance issue in ShardedRedisBroadcastChannel
This commit is contained in:
@ -8,6 +8,7 @@ from typing import Self
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis.client import PubSub
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Redis | RedisCluster,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._client = client
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
|
||||
@ -42,6 +42,7 @@ class Topic:
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis,
|
||||
redis_client: Redis | RedisCluster,
|
||||
):
|
||||
self._client = redis_client
|
||||
|
||||
@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel:
|
||||
|
||||
|
||||
class ShardedTopic:
|
||||
def __init__(self, redis_client: Redis, topic: str):
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
|
||||
@ -40,6 +40,7 @@ class ShardedTopic:
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisShardedSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
@ -68,7 +69,15 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
#
|
||||
# Since we have already filtered at the caller's site, we can safely set
|
||||
# `ignore_subscribe_messages=False`.
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
|
||||
if isinstance(self._client, RedisCluster):
|
||||
node = self._client.get_node_from_key(self._topic)
|
||||
return self._pubsub.get_sharded_message(
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=0.1,
|
||||
target_node=node,
|
||||
)
|
||||
else:
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
return "smessage"
|
||||
|
||||
Reference in New Issue
Block a user