[V1][Metrics] add support for kv event publishing (#16750)

Signed-off-by: alec-flowers <aflowers@nvidia.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Alec
2025-04-30 16:44:45 +02:00
committed by GitHub
parent 77073c77bc
commit 0be6d05b5e
15 changed files with 1185 additions and 53 deletions

View File

@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional, Union
import msgspec
import msgspec.msgpack
import pytest
import zmq
from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import EventPublisherFactory
from .test_events import SampleBatch
@pytest.fixture
def random_port():
"""Generate a random port number for testing"""
return random.randint(10000, 60000)
@pytest.fixture
def publisher_config(random_port, request):
"""Create a publisher config with inproc transport"""
how = request.param if hasattr(request, "param") else "inproc"
if how == "inproc":
endpoint = f"inproc://test-{random_port}"
replay_endpoint = endpoint + "-replay"
else:
endpoint = f"tcp://*:{random_port}"
replay_endpoint = f"tcp://*:{random_port + 1}"
return KVEventsConfig(enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test")
@pytest.fixture
def publisher(publisher_config):
"""Create and return a publisher instance"""
pub = EventPublisherFactory.create(publisher_config)
yield pub
pub.shutdown()
@pytest.fixture
def subscriber(publisher_config):
"""Create and return a subscriber for testing"""
endpoint = publisher_config.endpoint
replay_endpoint = publisher_config.replay_endpoint
if endpoint.startswith("tcp://*"):
endpoint = endpoint.replace("*", "127.0.0.1")
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
yield sub
sub.close()
class MockSubscriber:
"""Helper class to receive and verify published events"""
def __init__(self,
pub_endpoint: str,
replay_endpoint: Optional[str] = None,
topic: str = "",
decode_type=SampleBatch):
self.ctx = zmq.Context.instance()
# Set up subscriber socket
self.sub = self.ctx.socket(zmq.SUB)
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
self.sub.connect(pub_endpoint)
# Set up replay socket if provided
self.replay = None
if replay_endpoint:
self.replay = self.ctx.socket(zmq.REQ)
self.replay.connect(replay_endpoint)
self.topic = topic
self.topic_bytes = topic.encode('utf-8')
self.received_msgs: list[tuple[int, SampleBatch]] = []
self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
def receive_one(self,
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
"""Receive a single message with timeout"""
if not self.sub.poll(timeout):
return None
topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
assert topic_bytes == self.topic_bytes
seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
self.last_seq = seq
self.received_msgs.append((seq, data))
return seq, data
def request_replay(self, start_seq: int) -> None:
"""Request replay of messages starting from start_seq"""
if not self.replay:
raise ValueError("Replay socket not initialized")
self.replay.send(start_seq.to_bytes(8, "big"))
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages"""
if not self.replay:
raise ValueError("Replay socket not initialized")
replayed: list[tuple[int, SampleBatch]] = []
while True:
try:
if not self.replay.poll(1000):
break
frames = self.replay.recv_multipart()
if not frames or not frames[-1]:
# End of replay marker
break
seq_bytes, payload = frames
seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
replayed.append((seq, data))
except zmq.ZMQError as _:
break
return replayed
def close(self):
"""Clean up resources"""
self.sub.close()
if self.replay:
self.replay.close()

View File

@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
import threading
import time
import msgspec
import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
NullEventPublisher)
class EventSample(
msgspec.Struct,
tag=True, # type: ignore
array_like=True # type: ignore
):
"""Test event for publisher testing"""
id: int
value: str
class SampleBatch(EventBatch):
"""Test event batch for publisher testing"""
events: list[EventSample]
def create_test_events(count: int) -> SampleBatch:
"""Create a batch of test events"""
events = [EventSample(id=i, value=f"test-{i}") for i in range(count)]
return SampleBatch(ts=time.time(), events=events)
def test_basic_publishing(publisher, subscriber):
"""Test basic event publishing works"""
test_batch = create_test_events(5)
publisher.publish(test_batch)
result = subscriber.receive_one(timeout=1000)
assert result is not None, "No message received"
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert received.ts == pytest.approx(test_batch.ts,
abs=0.1), ("Timestamp mismatch")
assert len(received.events) == len(
test_batch.events), ("Number of events mismatch")
for i, event in enumerate(received.events):
assert event.id == i, "Event id mismatch"
assert event.value == f"test-{i}", "Event value mismatch"
def test_multiple_events(publisher, subscriber):
"""Test publishing and receiving multiple event batches"""
for _ in range(10):
batch = create_test_events(2)
publisher.publish(batch)
received = []
for _ in range(10):
data = subscriber.receive_one(timeout=100)
if data:
received.append(data)
assert len(received) == 10, "Number of messages mismatch"
seqs = [seq for seq, _ in received]
assert seqs == list(range(10)), "Sequence numbers mismatch"
def test_replay_mechanism(publisher, subscriber):
"""Test the replay mechanism works correctly"""
for _ in range(19):
batch = create_test_events(1)
publisher.publish(batch)
time.sleep(0.5) # Need publisher to process above requests
subscriber.request_replay(10)
batch = create_test_events(1)
publisher.publish(batch) # 20th message
replayed = subscriber.receive_replay()
assert len(replayed) > 0, "No replayed messages received"
seqs = [seq for seq, _ in replayed]
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
assert seqs == list(range(min(seqs),
max(seqs) +
1)), ("Replayed messages not consecutive")
def test_buffer_limit(publisher, subscriber, publisher_config):
"""Test buffer limit behavior"""
buffer_size = publisher_config.buffer_steps
# Publish more events than the buffer can hold
for i in range(buffer_size + 10):
batch = create_test_events(1)
publisher.publish(batch)
time.sleep(0.5) # Need publisher to process above requests
subscriber.request_replay(0)
batch = create_test_events(1)
publisher.publish(batch)
replayed = subscriber.receive_replay()
assert len(replayed) <= buffer_size, "Can't replay more than buffer size"
oldest_seq = min(seq for seq, _ in replayed)
assert oldest_seq >= 10, "The oldest sequence should be at least 10"
def test_topic_filtering(publisher_config):
"""
Test that a subscriber only receives messages matching its topic filter
"""
publisher_config.replay_endpoint = None
cfg = publisher_config.model_copy()
cfg.topic = "foo"
pub = EventPublisherFactory.create(cfg)
from .conftest import MockSubscriber
sub_foo = MockSubscriber(cfg.endpoint, None, "foo")
sub_bar = MockSubscriber(cfg.endpoint, None, "bar")
try:
time.sleep(0.1)
for _ in range(3):
pub.publish(create_test_events(1))
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
assert all(msg is not None for msg in foo_received), (
"Subscriber with matching topic should receive messages")
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
assert all(msg is None for msg in bar_received), (
"Subscriber with non-matching topic should receive no messages")
finally:
pub.shutdown()
sub_foo.close()
sub_bar.close()
def test_high_volume(publisher, subscriber):
"""Test publishing and receiving a high volume of events"""
num_batches = 10_000
events_per_batch = 100
# Publish events in a separate thread to not block
def publish_events():
for i in range(num_batches):
batch = create_test_events(events_per_batch)
publisher.publish(batch)
# Small delay to avoid overwhelming
if i % 100 == 0:
time.sleep(0.01)
received: list[tuple[int, SampleBatch]] = []
publisher_thread = threading.Thread(target=publish_events)
publisher_thread.start()
start_time = time.time()
while len(received) < num_batches:
if time.time() - start_time > 10: # Timeout after 10 seconds
break
result = subscriber.receive_one(timeout=100)
if result:
received.append(result)
publisher_thread.join()
assert len(received) >= num_batches * 0.9, (
"We should have received most messages")
seqs = [seq for seq, _ in received]
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
def test_null_publisher():
"""Test that NullEventPublisher can be used without errors"""
publisher = NullEventPublisher()
# This should not raise any errors
batch = create_test_events(5)
publisher.publish(batch)
publisher.shutdown()

View File

@ -6,6 +6,7 @@ from typing import Optional
import pytest
import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
@ -48,9 +49,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
num_blocks=num_blocks,
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
)
],
)
@ -783,6 +785,60 @@ def test_prefix_cache_stats_disabled():
assert manager.prefix_cache_stats is None
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events(blocks_to_cache: int):
block_size = 16
num_blocks = blocks_to_cache + 1
# Allocate Blocks
# Should see a single block stored event with a blocks_to_cache number of
# block hashes
# take_events should reset the kv_event_queue
manager = KVCacheManager(
make_kv_cache_config(block_size, num_blocks),
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
)
num_tokens = block_size * blocks_to_cache
req0 = make_request("0", list(range(num_tokens)))
_ = manager.allocate_slots(req0, num_tokens)
events = manager.take_events()
block = events[-1]
assert (len(block.block_hashes) == blocks_to_cache == len(
manager.block_pool.cached_block_hash_to_block))
assert len(block.token_ids) == block.block_size * len(block.block_hashes)
assert len(manager.block_pool.kv_event_queue) == 0
stored_block_hash = block.block_hashes
# Remove blocks and send another request
# Should see block_to_cache number of removed block events and a new block
# stored event
manager.free(req0)
req1 = make_request("1", list(range(num_tokens)))
_ = manager.allocate_slots(req1, num_tokens)
events = manager.take_events()
for blocks in events[:-1]:
assert blocks.block_hashes[0] in stored_block_hash
assert len(events) == blocks_to_cache + 1
assert (isinstance(events[-2], BlockRemoved))
assert (len(events[-1].block_hashes) == blocks_to_cache == len(
manager.block_pool.cached_block_hash_to_block))
# All Blocks Cleared
# Should see a single all blocks cleared event
manager.free(req1)
manager.reset_prefix_cache()
events = manager.take_events()
assert isinstance(events[-1], AllBlocksCleared)
assert len(manager.block_pool.cached_block_hash_to_block) == 0
def test_eagle_enabled_removes_last_block():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""

View File

@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from ...distributed.conftest import publisher_config, random_port # noqa: F401
from tests.v1.engine.utils import FULL_STRINGS # isort: skip
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]

View File

@ -11,6 +11,7 @@ import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
@ -20,6 +21,7 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient)
from vllm.v1.executor.abstract import Executor
from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
if not current_platform.is_cuda():
@ -199,54 +201,142 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
log_stats=True,
)
MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
"""Normal Request Cycle."""
try:
MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
"""Normal Request Cycle."""
requests = [make_request(params) for _ in range(10)]
request_ids = [req.request_id for req in requests]
requests = [make_request(params) for _ in range(10)]
request_ids = [req.request_id for req in requests]
# Add requests to the engine.
for request in requests:
await client.add_request_async(request)
await asyncio.sleep(0.01)
# Add requests to the engine.
for request in requests:
await client.add_request_async(request)
await asyncio.sleep(0.01)
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
for req_id in request_ids:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{outputs[req_id]=}, {MAX_TOKENS=}")
"""Abort Request Cycle."""
# Add requests to the engine.
for idx, request in enumerate(requests):
await client.add_request_async(request)
await asyncio.sleep(0.01)
if idx % 2 == 0:
await client.abort_requests_async([request.request_id])
outputs = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
for idx, req_id in enumerate(request_ids):
if idx % 2 == 0:
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
else:
for req_id in request_ids:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
"""Utility method invocation"""
f"{outputs[req_id]=}, {MAX_TOKENS=}")
"""Abort Request Cycle."""
core_client: AsyncMPClient = client
# Add requests to the engine.
for idx, request in enumerate(requests):
await client.add_request_async(request)
await asyncio.sleep(0.01)
if idx % 2 == 0:
await client.abort_requests_async([request.request_id])
result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg"
outputs = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
with pytest.raises(Exception) as e_info:
await core_client.call_utility_async("echo", None, "help!")
for idx, req_id in enumerate(request_ids):
if idx % 2 == 0:
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
else:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
"""Utility method invocation"""
assert str(e_info.value) == "Call to echo method failed: help!"
core_client: AsyncMPClient = client
result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg"
with pytest.raises(Exception) as e_info:
await core_client.call_utility_async("echo", None, "help!")
assert str(e_info.value) == "Call to echo method failed: help!"
finally:
client.shutdown()
@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],
indirect=["publisher_config"],
)
def test_kv_cache_events(
monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool,
publisher_config,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
block_size = 16
num_blocks = 2
engine_args = EngineArgs(model=MODEL_NAME,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size)
engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
time.sleep(0.1)
subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic,
decode_type=KVEventBatch)
try:
custom_tokens = list(range(num_blocks * block_size))
request = EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=custom_tokens,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(
max_tokens=1), # Short completion for speed
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
)
client.add_request(request)
outputs: dict[str, list] = {request.request_id: []}
loop_until_done(client, outputs)
result = subscriber.receive_one(timeout=1000)
assert result is not None, "No message received"
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert len(received.events) == 1, (
"We should have exactly one BlockStored event")
event = received.events[0]
assert isinstance(
event, BlockStored), ("We should have a BlockStored event")
assert len(event.block_hashes) == num_blocks, (
"We should have a BlockStored event with 2 block_hashes")
assert event.block_size == block_size, (
"Block size should be the same as the block size")
assert event.parent_block_hash is None, (
"Parent block hash should be None")
assert event.lora_id is None, "Lora id should be None"
assert len(event.token_ids) == num_blocks * block_size, (
"Token ids should be the same as the custom tokens")
assert event.token_ids == custom_tokens, (
"Token ids should be the same as the custom tokens")
finally:
client.shutdown()
return
@pytest.mark.timeout(10)