mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
Merge branch 'fix/redis-pubsub-perf' into feat/hitl
This commit is contained in:
@ -7,6 +7,7 @@ from typing import Self
|
|||||||
|
|
||||||
from libs.broadcast_channel.channel import Subscription
|
from libs.broadcast_channel.channel import Subscription
|
||||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||||
|
from redis import Redis, RedisCluster
|
||||||
from redis.client import PubSub
|
from redis.client import PubSub
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
client: Redis | RedisCluster,
|
||||||
pubsub: PubSub,
|
pubsub: PubSub,
|
||||||
topic: str,
|
topic: str,
|
||||||
):
|
):
|
||||||
# The _pubsub is None only if the subscription is closed.
|
# The _pubsub is None only if the subscription is closed.
|
||||||
|
self._client = client
|
||||||
self._pubsub: PubSub | None = pubsub
|
self._pubsub: PubSub | None = pubsub
|
||||||
self._topic = topic
|
self._topic = topic
|
||||||
self._closed = threading.Event()
|
self._closed = threading.Event()
|
||||||
|
|||||||
@ -42,6 +42,7 @@ class Topic:
|
|||||||
|
|
||||||
def subscribe(self) -> Subscription:
|
def subscribe(self) -> Subscription:
|
||||||
return _RedisSubscription(
|
return _RedisSubscription(
|
||||||
|
client=self._client,
|
||||||
pubsub=self._client.pubsub(),
|
pubsub=self._client.pubsub(),
|
||||||
topic=self._topic,
|
topic=self._topic,
|
||||||
)
|
)
|
||||||
@ -63,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase):
|
|||||||
|
|
||||||
def _get_message(self) -> dict | None:
|
def _get_message(self) -> dict | None:
|
||||||
assert self._pubsub is not None
|
assert self._pubsub is not None
|
||||||
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
|
||||||
|
|
||||||
def _get_message_type(self) -> str:
|
def _get_message_type(self) -> str:
|
||||||
return "message"
|
return "message"
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||||
from redis import Redis
|
from redis import Redis, RedisCluster
|
||||||
|
|
||||||
from ._subscription import RedisSubscriptionBase
|
from ._subscription import RedisSubscriptionBase
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis_client: Redis,
|
redis_client: Redis | RedisCluster,
|
||||||
):
|
):
|
||||||
self._client = redis_client
|
self._client = redis_client
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel:
|
|||||||
|
|
||||||
|
|
||||||
class ShardedTopic:
|
class ShardedTopic:
|
||||||
def __init__(self, redis_client: Redis, topic: str):
|
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||||
self._client = redis_client
|
self._client = redis_client
|
||||||
self._topic = topic
|
self._topic = topic
|
||||||
|
|
||||||
@ -40,6 +40,7 @@ class ShardedTopic:
|
|||||||
|
|
||||||
def subscribe(self) -> Subscription:
|
def subscribe(self) -> Subscription:
|
||||||
return _RedisShardedSubscription(
|
return _RedisShardedSubscription(
|
||||||
|
client=self._client,
|
||||||
pubsub=self._client.pubsub(),
|
pubsub=self._client.pubsub(),
|
||||||
topic=self._topic,
|
topic=self._topic,
|
||||||
)
|
)
|
||||||
@ -68,7 +69,19 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
|||||||
#
|
#
|
||||||
# Since we have already filtered at the caller's site, we can safely set
|
# Since we have already filtered at the caller's site, we can safely set
|
||||||
# `ignore_subscribe_messages=False`.
|
# `ignore_subscribe_messages=False`.
|
||||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
|
if isinstance(self._client, RedisCluster):
|
||||||
|
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
|
||||||
|
# would use busy-looping to wait for incoming message, consuming excessive CPU quota.
|
||||||
|
#
|
||||||
|
# Here we specify the `target_node` to mitigate this problem.
|
||||||
|
node = self._client.get_node_from_key(self._topic)
|
||||||
|
return self._pubsub.get_sharded_message(
|
||||||
|
ignore_subscribe_messages=False,
|
||||||
|
timeout=1,
|
||||||
|
target_node=node,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||||
|
|
||||||
def _get_message_type(self) -> str:
|
def _get_message_type(self) -> str:
|
||||||
return "smessage"
|
return "smessage"
|
||||||
|
|||||||
@ -181,6 +181,7 @@ class TestShardedTopic:
|
|||||||
subscription = sharded_topic.subscribe()
|
subscription = sharded_topic.subscribe()
|
||||||
|
|
||||||
assert isinstance(subscription, _RedisShardedSubscription)
|
assert isinstance(subscription, _RedisShardedSubscription)
|
||||||
|
assert subscription._client is mock_redis_client
|
||||||
assert subscription._pubsub is mock_redis_client.pubsub.return_value
|
assert subscription._pubsub is mock_redis_client.pubsub.return_value
|
||||||
assert subscription._topic == "test-sharded-topic"
|
assert subscription._topic == "test-sharded-topic"
|
||||||
|
|
||||||
@ -200,6 +201,11 @@ class SubscriptionTestCase:
|
|||||||
class TestRedisSubscription:
|
class TestRedisSubscription:
|
||||||
"""Test cases for the _RedisSubscription class."""
|
"""Test cases for the _RedisSubscription class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis_client(self) -> MagicMock:
|
||||||
|
client = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_pubsub(self) -> MagicMock:
|
def mock_pubsub(self) -> MagicMock:
|
||||||
"""Create a mock PubSub instance for testing."""
|
"""Create a mock PubSub instance for testing."""
|
||||||
@ -211,9 +217,12 @@ class TestRedisSubscription:
|
|||||||
return pubsub
|
return pubsub
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
|
def subscription(
|
||||||
|
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||||
|
) -> Generator[_RedisSubscription, None, None]:
|
||||||
"""Create a _RedisSubscription instance for testing."""
|
"""Create a _RedisSubscription instance for testing."""
|
||||||
subscription = _RedisSubscription(
|
subscription = _RedisSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic="test-topic",
|
topic="test-topic",
|
||||||
)
|
)
|
||||||
@ -228,13 +237,15 @@ class TestRedisSubscription:
|
|||||||
|
|
||||||
# ==================== Lifecycle Tests ====================
|
# ==================== Lifecycle Tests ====================
|
||||||
|
|
||||||
def test_subscription_initialization(self, mock_pubsub: MagicMock):
|
def test_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||||
"""Test that subscription is properly initialized."""
|
"""Test that subscription is properly initialized."""
|
||||||
subscription = _RedisSubscription(
|
subscription = _RedisSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic="test-topic",
|
topic="test-topic",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert subscription._client is mock_redis_client
|
||||||
assert subscription._pubsub is mock_pubsub
|
assert subscription._pubsub is mock_pubsub
|
||||||
assert subscription._topic == "test-topic"
|
assert subscription._topic == "test-topic"
|
||||||
assert not subscription._closed.is_set()
|
assert not subscription._closed.is_set()
|
||||||
@ -486,9 +497,12 @@ class TestRedisSubscription:
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
def test_subscription_scenarios(
|
||||||
|
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||||
|
):
|
||||||
"""Test various subscription scenarios using table-driven approach."""
|
"""Test various subscription scenarios using table-driven approach."""
|
||||||
subscription = _RedisSubscription(
|
subscription = _RedisSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic="test-topic",
|
topic="test-topic",
|
||||||
)
|
)
|
||||||
@ -572,7 +586,7 @@ class TestRedisSubscription:
|
|||||||
# Close should still work
|
# Close should still work
|
||||||
subscription.close() # Should not raise
|
subscription.close() # Should not raise
|
||||||
|
|
||||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||||
"""Test various channel name formats."""
|
"""Test various channel name formats."""
|
||||||
channel_names = [
|
channel_names = [
|
||||||
"simple",
|
"simple",
|
||||||
@ -586,6 +600,7 @@ class TestRedisSubscription:
|
|||||||
|
|
||||||
for channel_name in channel_names:
|
for channel_name in channel_names:
|
||||||
subscription = _RedisSubscription(
|
subscription = _RedisSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic=channel_name,
|
topic=channel_name,
|
||||||
)
|
)
|
||||||
@ -604,6 +619,11 @@ class TestRedisSubscription:
|
|||||||
class TestRedisShardedSubscription:
|
class TestRedisShardedSubscription:
|
||||||
"""Test cases for the _RedisShardedSubscription class."""
|
"""Test cases for the _RedisShardedSubscription class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis_client(self) -> MagicMock:
|
||||||
|
client = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_pubsub(self) -> MagicMock:
|
def mock_pubsub(self) -> MagicMock:
|
||||||
"""Create a mock PubSub instance for testing."""
|
"""Create a mock PubSub instance for testing."""
|
||||||
@ -615,9 +635,12 @@ class TestRedisShardedSubscription:
|
|||||||
return pubsub
|
return pubsub
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]:
|
def sharded_subscription(
|
||||||
|
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||||
|
) -> Generator[_RedisShardedSubscription, None, None]:
|
||||||
"""Create a _RedisShardedSubscription instance for testing."""
|
"""Create a _RedisShardedSubscription instance for testing."""
|
||||||
subscription = _RedisShardedSubscription(
|
subscription = _RedisShardedSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic="test-sharded-topic",
|
topic="test-sharded-topic",
|
||||||
)
|
)
|
||||||
@ -634,13 +657,15 @@ class TestRedisShardedSubscription:
|
|||||||
|
|
||||||
# ==================== Lifecycle Tests ====================
|
# ==================== Lifecycle Tests ====================
|
||||||
|
|
||||||
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock):
|
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||||
"""Test that sharded subscription is properly initialized."""
|
"""Test that sharded subscription is properly initialized."""
|
||||||
subscription = _RedisShardedSubscription(
|
subscription = _RedisShardedSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic="test-sharded-topic",
|
topic="test-sharded-topic",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert subscription._client is mock_redis_client
|
||||||
assert subscription._pubsub is mock_pubsub
|
assert subscription._pubsub is mock_pubsub
|
||||||
assert subscription._topic == "test-sharded-topic"
|
assert subscription._topic == "test-sharded-topic"
|
||||||
assert not subscription._closed.is_set()
|
assert not subscription._closed.is_set()
|
||||||
@ -808,6 +833,37 @@ class TestRedisShardedSubscription:
|
|||||||
assert not sharded_subscription._queue.empty()
|
assert not sharded_subscription._queue.empty()
|
||||||
assert sharded_subscription._queue.get_nowait() == b"test sharded payload"
|
assert sharded_subscription._queue.get_nowait() == b"test sharded payload"
|
||||||
|
|
||||||
|
def test_get_message_uses_target_node_for_cluster_client(self, mock_pubsub: MagicMock, monkeypatch):
|
||||||
|
"""Test that cluster clients use target_node for sharded messages."""
|
||||||
|
|
||||||
|
class DummyRedisCluster:
|
||||||
|
def __init__(self):
|
||||||
|
self.get_node_from_key = MagicMock(return_value="node-1")
|
||||||
|
|
||||||
|
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.RedisCluster", DummyRedisCluster)
|
||||||
|
|
||||||
|
client = DummyRedisCluster()
|
||||||
|
subscription = _RedisShardedSubscription(
|
||||||
|
client=client,
|
||||||
|
pubsub=mock_pubsub,
|
||||||
|
topic="test-sharded-topic",
|
||||||
|
)
|
||||||
|
mock_pubsub.get_sharded_message.return_value = {
|
||||||
|
"type": "smessage",
|
||||||
|
"channel": "test-sharded-topic",
|
||||||
|
"data": b"payload",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = subscription._get_message()
|
||||||
|
|
||||||
|
client.get_node_from_key.assert_called_once_with("test-sharded-topic")
|
||||||
|
mock_pubsub.get_sharded_message.assert_called_once_with(
|
||||||
|
ignore_subscribe_messages=False,
|
||||||
|
timeout=0.1,
|
||||||
|
target_node="node-1",
|
||||||
|
)
|
||||||
|
assert result == mock_pubsub.get_sharded_message.return_value
|
||||||
|
|
||||||
def test_listener_thread_ignores_subscribe_messages(
|
def test_listener_thread_ignores_subscribe_messages(
|
||||||
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
|
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
|
||||||
):
|
):
|
||||||
@ -913,9 +969,12 @@ class TestRedisShardedSubscription:
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
def test_sharded_subscription_scenarios(
|
||||||
|
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||||
|
):
|
||||||
"""Test various sharded subscription scenarios using table-driven approach."""
|
"""Test various sharded subscription scenarios using table-driven approach."""
|
||||||
subscription = _RedisShardedSubscription(
|
subscription = _RedisShardedSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic="test-sharded-topic",
|
topic="test-sharded-topic",
|
||||||
)
|
)
|
||||||
@ -999,7 +1058,7 @@ class TestRedisShardedSubscription:
|
|||||||
# Close should still work
|
# Close should still work
|
||||||
sharded_subscription.close() # Should not raise
|
sharded_subscription.close() # Should not raise
|
||||||
|
|
||||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||||
"""Test various sharded channel name formats."""
|
"""Test various sharded channel name formats."""
|
||||||
channel_names = [
|
channel_names = [
|
||||||
"simple",
|
"simple",
|
||||||
@ -1013,6 +1072,7 @@ class TestRedisShardedSubscription:
|
|||||||
|
|
||||||
for channel_name in channel_names:
|
for channel_name in channel_names:
|
||||||
subscription = _RedisShardedSubscription(
|
subscription = _RedisShardedSubscription(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic=channel_name,
|
topic=channel_name,
|
||||||
)
|
)
|
||||||
@ -1060,6 +1120,11 @@ class TestRedisSubscriptionCommon:
|
|||||||
"""Parameterized fixture providing subscription type and class."""
|
"""Parameterized fixture providing subscription type and class."""
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis_client(self) -> MagicMock:
|
||||||
|
client = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_pubsub(self) -> MagicMock:
|
def mock_pubsub(self) -> MagicMock:
|
||||||
"""Create a mock PubSub instance for testing."""
|
"""Create a mock PubSub instance for testing."""
|
||||||
@ -1075,11 +1140,12 @@ class TestRedisSubscriptionCommon:
|
|||||||
return pubsub
|
return pubsub
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def subscription(self, subscription_params, mock_pubsub: MagicMock):
|
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||||
"""Create a subscription instance based on parameterized type."""
|
"""Create a subscription instance based on parameterized type."""
|
||||||
subscription_type, subscription_class = subscription_params
|
subscription_type, subscription_class = subscription_params
|
||||||
topic_name = f"test-{subscription_type}-topic"
|
topic_name = f"test-{subscription_type}-topic"
|
||||||
subscription = subscription_class(
|
subscription = subscription_class(
|
||||||
|
client=mock_redis_client,
|
||||||
pubsub=mock_pubsub,
|
pubsub=mock_pubsub,
|
||||||
topic=topic_name,
|
topic=topic_name,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user