feat: support redis xstream (#32586)

This commit is contained in:
wangxiaolei
2026-03-04 13:18:55 +08:00
committed by GitHub
parent e14b09d4db
commit 2f4c740d46
6 changed files with 558 additions and 21 deletions

View File

@ -0,0 +1,145 @@
import time
import pytest
from libs.broadcast_channel.redis.streams_channel import (
StreamsBroadcastChannel,
StreamsTopic,
_StreamsSubscription,
)
class FakeStreamsRedis:
"""Minimal in-memory Redis Streams stub for unit tests.
- Stores entries per key as [(id, {b"data": bytes}), ...]
- xadd appends entries and returns an auto-increment id like "1-0"
- xread returns entries strictly greater than last_id
- expire is recorded but has no effect on behavior
"""
def __init__(self) -> None:
self._store: dict[str, list[tuple[str, dict]]] = {}
self._next_id: dict[str, int] = {}
self._expire_calls: dict[str, int] = {}
# Publisher API
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
"""Append entry to stream; accept optional maxlen for API compatibility.
The test double ignores maxlen trimming semantics; only records the entry.
"""
n = self._next_id.get(key, 0) + 1
self._next_id[key] = n
entry_id = f"{n}-0"
self._store.setdefault(key, []).append((entry_id, fields))
return entry_id
def expire(self, key: str, seconds: int) -> None:
self._expire_calls[key] = self._expire_calls.get(key, 0) + 1
# Consumer API
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
# Expect a single key
assert len(streams) == 1
key, last_id = next(iter(streams.items()))
entries = self._store.get(key, [])
# Find position strictly greater than last_id
start_idx = 0
if last_id != "0-0":
for i, (eid, _f) in enumerate(entries):
if eid == last_id:
start_idx = i + 1
break
if start_idx >= len(entries):
# Simulate blocking wait (bounded) if requested
if block and block > 0:
time.sleep(min(0.01, block / 1000.0))
return []
end_idx = len(entries) if count is None else min(len(entries), start_idx + count)
batch = entries[start_idx:end_idx]
return [(key, batch)]
@pytest.fixture
def fake_redis() -> FakeStreamsRedis:
    """Provide a fresh in-memory Redis Streams stub per test."""
    return FakeStreamsRedis()
@pytest.fixture
def streams_channel(fake_redis: FakeStreamsRedis) -> StreamsBroadcastChannel:
    """Channel under test, wired to the fake client with a 60s retention."""
    channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60)
    return channel
class TestStreamsBroadcastChannel:
    """Publishing surface of the Streams-backed broadcast channel."""

    def test_topic_creation(self, streams_channel: StreamsBroadcastChannel, fake_redis: FakeStreamsRedis):
        created = streams_channel.topic("alpha")
        assert isinstance(created, StreamsTopic)
        assert created._client is fake_redis
        assert created._topic == "alpha"
        assert created._key == "stream:alpha"

    def test_publish_calls_xadd_and_expire(
        self,
        streams_channel: StreamsBroadcastChannel,
        fake_redis: FakeStreamsRedis,
    ):
        beta = streams_channel.topic("beta")
        body = b"hello"
        beta.publish(body)
        # Payload is stored under the stream key with a bytes field name.
        _entry_id, stored_fields = fake_redis._store["stream:beta"][0]
        assert stored_fields == {b"data": body}
        # Publishing should also refresh the stream TTL at least once.
        assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
class TestStreamsSubscription:
    """Subscription lifecycle: replay, bounded waits, and close semantics."""

    def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel):
        gamma = streams_channel.topic("gamma")
        # Events published before subscribing must still be replayed.
        gamma.publish(b"e1")
        gamma.publish(b"e2")
        subscription = gamma.subscribe()
        assert isinstance(subscription, _StreamsSubscription)
        collected: list[bytes] = []
        with subscription:
            # Let the listener thread perform an xread first.
            time.sleep(0.05)
            # Drain via receive() so the test cannot iterate indefinitely.
            for _ in range(5):
                item = subscription.receive(timeout=0.1)
                if item is None:
                    break
                collected.append(item)
        assert collected == [b"e1", b"e2"]

    def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel):
        delta = streams_channel.topic("delta")
        subscription = delta.subscribe()
        with subscription:
            # Nothing published yet, so the bounded wait yields None.
            assert subscription.receive(timeout=0.05) is None

    def test_close_stops_listener(self, streams_channel: StreamsBroadcastChannel):
        epsilon = streams_channel.topic("epsilon")
        subscription = epsilon.subscribe()
        with subscription:
            # Closing while the listener runs must not raise.
            subscription.close()
            # After close, receive should raise SubscriptionClosedError.
            from libs.broadcast_channel.exc import SubscriptionClosedError
            with pytest.raises(SubscriptionClosedError):
                subscription.receive()

    def test_no_expire_when_zero_retention(self, fake_redis: FakeStreamsRedis):
        zero_retention = StreamsBroadcastChannel(fake_redis, retention_seconds=0)
        zeta = zero_retention.topic("zeta")
        zeta.publish(b"payload")
        # With retention disabled, publish must never set a TTL.
        assert fake_redis._expire_calls.get("stream:zeta") is None

View File

@ -0,0 +1,197 @@
import json
import uuid
from collections import defaultdict, deque
import pytest
from core.app.apps.message_generator import MessageGenerator
from models.model import AppMode
from services.app_generate_service import AppGenerateService
# -----------------------------
# Fakes for Redis Pub/Sub flow
# -----------------------------
class _FakePubSub:
def __init__(self, store: dict[str, deque[bytes]]):
self._store = store
self._subs: set[str] = set()
self._closed = False
def subscribe(self, topic: str) -> None:
self._subs.add(topic)
def unsubscribe(self, topic: str) -> None:
self._subs.discard(topic)
def close(self) -> None:
self._closed = True
def get_message(self, ignore_subscribe_messages: bool = True, timeout: int | float | None = 1):
# simulate a non-blocking poll; return first available
if self._closed:
return None
for t in list(self._subs):
q = self._store.get(t)
if q and len(q) > 0:
payload = q.popleft()
return {"type": "message", "channel": t, "data": payload}
# no message
return None
class _FakeRedisClient:
def __init__(self, store: dict[str, deque[bytes]]):
self._store = store
def pubsub(self):
return _FakePubSub(self._store)
def publish(self, topic: str, payload: bytes) -> None:
self._store.setdefault(topic, deque()).append(payload)
# ------------------------------------
# Fakes for Redis Streams (XADD/XREAD)
# ------------------------------------
class _FakeStreams:
def __init__(self) -> None:
# key -> list[(id, {field: value})]
self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list)
self._seq: dict[str, int] = defaultdict(int)
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
# maxlen is accepted for API compatibility with redis-py; ignored in this test double
self._seq[key] += 1
eid = f"{self._seq[key]}-0"
self._data[key].append((eid, fields))
return eid
def expire(self, key: str, seconds: int) -> None:
# no-op for tests
return None
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
assert len(streams) == 1
key, last_id = next(iter(streams.items()))
entries = self._data.get(key, [])
start = 0
if last_id != "0-0":
for i, (eid, _f) in enumerate(entries):
if eid == last_id:
start = i + 1
break
if start >= len(entries):
return []
end = len(entries) if count is None else min(len(entries), start + count)
return [(key, entries[start:end])]
@pytest.fixture
def _patch_get_channel_streams(monkeypatch):
    """Route the app's broadcast-channel accessor to an in-memory Streams fake.

    Patches both the source function and the alias imported by
    MessageGenerator, then forces AppGenerateService into "streams" mode.

    Fix: removed the dead local ``_get_channel`` helper — it was defined but
    never used (the monkeypatched attributes use lambdas instead).
    """
    from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel

    chan = StreamsBroadcastChannel(_FakeStreams(), retention_seconds=60)
    # Patch both the source and the imported alias used by MessageGenerator.
    monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan)
    monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan)
    # Ensure AppGenerateService sees streams mode.
    import services.app_generate_service as ags

    monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams", raising=False)
@pytest.fixture
def _patch_get_channel_pubsub(monkeypatch):
    """Route the app's broadcast-channel accessor to a fake Redis Pub/Sub.

    Patches both the source function and the alias imported by
    MessageGenerator, then forces AppGenerateService into "pubsub" mode.

    Fix: removed the dead local ``_get_channel`` helper — it was defined but
    never used (the monkeypatched attributes use lambdas instead).
    """
    from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel

    store: dict[str, deque[bytes]] = defaultdict(deque)
    chan = RedisBroadcastChannel(_FakeRedisClient(store))
    # Patch both the source and the imported alias used by MessageGenerator.
    monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan)
    monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan)
    # Ensure AppGenerateService sees pubsub mode.
    import services.app_generate_service as ags

    monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub", raising=False)
def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]):
    """Serialize each event as JSON and publish it on the generator's topic."""
    # Same topic derivation used by the consumer side (MessageGenerator).
    target = MessageGenerator.get_response_topic(app_mode, run_id)
    for event in events:
        target.publish(json.dumps(event).encode())
@pytest.mark.usefixtures("_patch_get_channel_streams")
def test_streams_full_flow_prepublish_and_replay():
    """Streams mode replays events published before the consumer attaches."""
    app_mode = AppMode.WORKFLOW
    run_id = str(uuid.uuid4())
    events = [{"event": "workflow_started"}, {"event": "workflow_finished"}]

    def start_task():
        # Publishes both events as soon as the task runs.
        _publish_events(app_mode, run_id, events)

    on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)
    # Retrieval begins before the subscription is established; in streams
    # mode the task also starts immediately.
    stream = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)
    seen = []
    for item in stream:
        if isinstance(item, str):
            # Skip ping events.
            continue
        seen.append(item)
        if item.get("event") == "workflow_finished":
            break
    assert [m.get("event") for m in seen] == ["workflow_started", "workflow_finished"]
@pytest.mark.usefixtures("_patch_get_channel_pubsub")
def test_pubsub_full_flow_start_on_subscribe_gated(monkeypatch):
    """Pub/Sub mode defers the producer until the subscription is live."""
    # Shrink the fallback timer in case it accidentally fires during the test.
    monkeypatch.setattr("services.app_generate_service.SSE_TASK_START_FALLBACK_MS", 50)
    app_mode = AppMode.WORKFLOW
    run_id = str(uuid.uuid4())
    published_order: list[str] = []

    def start_task():
        # Invoked on subscribe: emit both events in order.
        batch = [{"event": "workflow_started"}, {"event": "workflow_finished"}]
        _publish_events(app_mode, run_id, batch)
        published_order.extend([item["event"] for item in batch])

    on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)
    # The producer must not have run before the consumer subscribes.
    assert published_order == []
    stream = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)
    seen = []
    for item in stream:
        if isinstance(item, str):
            continue
        seen.append(item)
        if item.get("event") == "workflow_finished":
            break
    # Publish happened exactly once, and the consumer saw events in order.
    assert published_order == ["workflow_started", "workflow_finished"]
    assert [m.get("event") for m in seen] == ["workflow_started", "workflow_finished"]