feat: add data parallel rank to KVEventBatch (#18925)

This commit is contained in:
Yan Ru Pei
2025-06-03 17:14:20 -07:00
committed by GitHub
parent a8da78eac9
commit b712be98c7
6 changed files with 359 additions and 83 deletions

View File

@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
from .test_events import SampleBatch
DP_RANK = 0
@pytest.fixture
def random_port():
"""Generate a random port number for testing"""
return random.randint(10000, 60000)
return random.randint(10000, 59900)
@pytest.fixture
@ -30,21 +32,23 @@ def publisher_config(random_port, request):
replay_endpoint = endpoint + "-replay"
else:
endpoint = f"tcp://*:{random_port}"
replay_endpoint = f"tcp://*:{random_port + 1}"
replay_endpoint = f"tcp://*:{random_port + 100}"
return KVEventsConfig(enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test")
return KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test",
)
@pytest.fixture
def publisher(publisher_config):
"""Create and return a publisher instance"""
pub = EventPublisherFactory.create(publisher_config)
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
yield pub
pub.shutdown()
@ -60,7 +64,11 @@ def subscriber(publisher_config):
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
sub = MockSubscriber(
[endpoint],
[replay_endpoint] if replay_endpoint else None,
publisher_config.topic,
)
yield sub
sub.close()
@ -68,26 +76,37 @@ def subscriber(publisher_config):
class MockSubscriber:
"""Helper class to receive and verify published events"""
def __init__(self,
pub_endpoint: str,
replay_endpoint: Optional[str] = None,
topic: str = "",
decode_type=SampleBatch):
def __init__(
self,
pub_endpoints: Union[str, list[str]],
replay_endpoints: Optional[Union[str, list[str]]] = None,
topic: str = "",
decode_type=SampleBatch,
):
self.ctx = zmq.Context.instance()
# Set up subscriber socket
self.sub = self.ctx.socket(zmq.SUB)
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
self.sub.connect(pub_endpoint)
# Convert single endpoint to list for consistency
if isinstance(pub_endpoints, str):
pub_endpoints = [pub_endpoints]
if isinstance(replay_endpoints, str):
replay_endpoints = [replay_endpoints]
# Set up replay socket if provided
self.replay = None
if replay_endpoint:
self.replay = self.ctx.socket(zmq.REQ)
self.replay.connect(replay_endpoint)
# Set up subscriber socket - connect to all endpoints
self.sub = self.ctx.socket(zmq.SUB)
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
for endpoint in pub_endpoints:
self.sub.connect(endpoint)
# Set up replay sockets if provided
self.replay_sockets = []
if replay_endpoints:
for replay_endpoint in replay_endpoints:
replay = self.ctx.socket(zmq.REQ)
replay.connect(replay_endpoint)
self.replay_sockets.append(replay)
self.topic = topic
self.topic_bytes = topic.encode('utf-8')
self.topic_bytes = topic.encode("utf-8")
self.received_msgs: list[tuple[int, SampleBatch]] = []
self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
@ -107,25 +126,31 @@ class MockSubscriber:
self.received_msgs.append((seq, data))
return seq, data
def request_replay(self, start_seq: int) -> None:
def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
"""Request replay of messages starting from start_seq"""
if not self.replay:
raise ValueError("Replay socket not initialized")
if not self.replay_sockets:
raise ValueError("Replay sockets not initialized")
if socket_idx >= len(self.replay_sockets):
raise ValueError(f"Invalid socket index {socket_idx}")
self.replay.send(start_seq.to_bytes(8, "big"))
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages"""
if not self.replay:
raise ValueError("Replay socket not initialized")
def receive_replay(self,
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages from a specific replay socket"""
if not self.replay_sockets:
raise ValueError("Replay sockets not initialized")
if socket_idx >= len(self.replay_sockets):
raise ValueError(f"Invalid socket index {socket_idx}")
replay_socket = self.replay_sockets[socket_idx]
replayed: list[tuple[int, SampleBatch]] = []
while True:
try:
if not self.replay.poll(1000):
if not replay_socket.poll(1000):
break
frames = self.replay.recv_multipart()
frames = replay_socket.recv_multipart()
if not frames or not frames[-1]:
# End of replay marker
break
@ -142,5 +167,5 @@ class MockSubscriber:
def close(self):
"""Clean up resources"""
self.sub.close()
if self.replay:
self.replay.close()
for replay in self.replay_sockets:
replay.close()

View File

@ -9,6 +9,8 @@ import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
NullEventPublisher)
DP_RANK = 0
class EventSample(
msgspec.Struct,
@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
publisher_config.replay_endpoint = None
publisher_config.topic = "foo"
pub = EventPublisherFactory.create(publisher_config)
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
from .conftest import MockSubscriber
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber):
def test_null_publisher():
"""Test that NullEventPublisher can be used without errors"""
publisher = NullEventPublisher()
publisher = NullEventPublisher(DP_RANK)
# This should not raise any errors
batch = create_test_events(5)
publisher.publish(batch)
publisher.shutdown()
def test_data_parallel_rank_tagging(publisher_config):
    """Test that events are properly tagged with their data parallel rank.

    Two publishers are created from the same config, one for ``DP_RANK``
    and one for ``DP_RANK + 1``. A dedicated subscriber is attached to the
    endpoint each rank is expected to publish on, one batch is published
    per rank, and the received batches are checked for the matching
    ``data_parallel_rank`` and event count.
    """
    from .conftest import MockSubscriber

    publisher_config.topic = "foo"

    # Expected endpoints follow the per-rank offsetting the publisher is
    # expected to apply (rank 0 keeps the base endpoint).
    # NOTE(review): mirrors the offsetting rule described for the
    # publisher side (port + rank for TCP, "_dp<rank>" suffix otherwise);
    # confirm against the implementation in vllm.distributed.kv_events.
    base_endpoint = publisher_config.endpoint
    if "tcp://" in base_endpoint:
        # Derive the rank-1 port from the actual port. The fixture picks a
        # random port, so a hard-coded ":5557" -> ":5558" substitution
        # would silently leave the endpoint unchanged.
        address, _, port = base_endpoint.rpartition(":")
        expected_endpoint_0 = base_endpoint
        expected_endpoint_1 = f"{address}:{int(port) + 1}"
    else:
        # For inproc endpoints: inproc://test -> inproc://test_dp1
        expected_endpoint_0 = base_endpoint
        expected_endpoint_1 = base_endpoint + "_dp1"

    def _connectable(endpoint: str) -> str:
        # "tcp://*" is a bind-only wildcard; a subscriber has to connect
        # to a concrete address (same rewrite the subscriber fixture does).
        if endpoint.startswith("tcp://*"):
            return endpoint.replace("*", "127.0.0.1")
        return endpoint

    # Track everything that was created so the finally block can release
    # it even if a later constructor raises.
    publishers = []
    subscribers = []
    try:
        for rank in (DP_RANK, DP_RANK + 1):
            publishers.append(
                EventPublisherFactory.create(publisher_config, rank))
        pub_0, pub_1 = publishers

        for expected in (expected_endpoint_0, expected_endpoint_1):
            subscribers.append(
                MockSubscriber(_connectable(expected), None,
                               publisher_config.topic))
        sub_0, sub_1 = subscribers

        time.sleep(0.1)  # Let publishers start up

        # Publish events from different ranks; the differing batch sizes
        # let the assertions below tell the two streams apart.
        batch_0 = create_test_events(2)
        batch_1 = create_test_events(3)

        pub_0.publish(batch_0)
        pub_1.publish(batch_1)

        # Receive events from rank 0
        result_0 = sub_0.receive_one(timeout=200)
        assert result_0 is not None, "No message received from rank 0"
        _, received_0 = result_0

        # Receive events from rank 1
        result_1 = sub_1.receive_one(timeout=200)
        assert result_1 is not None, "No message received from rank 1"
        _, received_1 = result_1

        # Verify DP rank tagging
        assert received_0.data_parallel_rank == 0, (
            f"Expected DP rank 0, got {received_0.data_parallel_rank}")
        assert received_1.data_parallel_rank == 1, (
            f"Expected DP rank 1, got {received_1.data_parallel_rank}")

        # Verify event content is correct
        assert len(
            received_0.events) == 2, "Wrong number of events from rank 0"
        assert len(
            received_1.events) == 3, "Wrong number of events from rank 1"
    finally:
        for pub in publishers:
            pub.shutdown()
        for sub in subscribers:
            sub.close()