[V1][Metrics] add support for kv event publishing (#16750)

Signed-off-by: alec-flowers <aflowers@nvidia.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Alec
2025-04-30 16:44:45 +02:00
committed by GitHub
parent 77073c77bc
commit 0be6d05b5e
15 changed files with 1185 additions and 53 deletions

View File

@ -0,0 +1,86 @@
#!/bin/bash
# This file demonstrates the KV cache event publishing
# We will launch a vllm instances configured to publish KV cache
# events and launch a simple subscriber to log those events.
set -xe
echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
sleep 1
MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# a function that waits vLLM server to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
vllm serve $MODEL_NAME \
--port 8100 \
--max-model-len 100 \
--enforce-eager \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-events-config \
'{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &
wait_for_server 8100
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
sleep 1
# serve two example requests
output1=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')
# Cleanup commands
pkill -9 -u "$USER" -f python
pkill -9 -u "$USER" -f vllm
sleep 1
echo "Cleaned up"
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""

View File

@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union
import msgspec
import zmq
from msgspec.msgpack import Decoder
#
# Types copied from vllm.distributed.kv_events
#
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
gc=False):
ts: float
events: list[Any]
class KVCacheEvent(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False,
tag=True):
"""Base class for all KV cache-related events"""
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
class AllBlocksCleared(KVCacheEvent):
pass
class KVEventBatch(EventBatch):
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
def process_event(event_batch):
print(f"Received event batch at {event_batch.ts}:")
for event in event_batch.events:
print(f" - {event}")
def main():
decoder = Decoder(type=KVEventBatch)
last_seq = -1
context = zmq.Context()
# Set up the main subscription socket
sub = context.socket(zmq.SUB)
sub.connect("tcp://localhost:5557")
topic = "kv-events"
sub.setsockopt_string(zmq.SUBSCRIBE, topic)
# Initialize replay socket
replay = context.socket(zmq.REQ)
replay.connect("tcp://localhost:5558")
poller = zmq.Poller()
poller.register(replay, zmq.POLLIN)
print("Listening for KV cache events on topic:", topic)
while True:
try:
if sub.poll(50):
_, seq_bytes, payload = sub.recv_multipart()
seq = int.from_bytes(seq_bytes, "big")
if last_seq >= 0 and seq > last_seq + 1:
missed = seq - last_seq - 1
print(f"Missed {missed} messages"
f" (last: {last_seq}, current: {seq})")
replay.send((last_seq + 1).to_bytes(8, "big"))
while poller.poll(timeout=200):
seq_bytes, replay_payload = replay.recv_multipart()
if not replay_payload:
# End of replay marker is sent as an empty frame
# for the payload
break
replay_seq = int.from_bytes(seq_bytes, "big")
if replay_seq > last_seq:
event_batch = decoder.decode(replay_payload)
process_event(event_batch)
last_seq = replay_seq
if replay_seq >= seq - 1:
break
event_batch = decoder.decode(payload)
process_event(event_batch)
# ... do other periodic work or check for shutdown ...
except KeyboardInterrupt:
print("Interrupted")
break
except Exception as e:
print("Error decoding message:", e)
if __name__ == "__main__":
main()