[V1][Metrics] add support for kv event publishing (#16750)
Signed-off-by: alec-flowers <aflowers@nvidia.com> Signed-off-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
145
tests/distributed/conftest.py
Normal file
145
tests/distributed/conftest.py
Normal file
@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import random
|
||||
from typing import Optional, Union
|
||||
|
||||
import msgspec
|
||||
import msgspec.msgpack
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from vllm.config import KVEventsConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory
|
||||
|
||||
from .test_events import SampleBatch
|
||||
|
||||
|
||||
@pytest.fixture
def random_port():
    """Pick a pseudo-random port in the 10000-60000 range for test endpoints."""
    return random.randrange(10000, 60001)
|
||||
|
||||
|
||||
@pytest.fixture
def publisher_config(random_port, request):
    """Build a KVEventsConfig for tests.

    The transport is selected via indirect parametrization: "inproc"
    (default) or anything else for TCP.
    """
    transport = getattr(request, "param", "inproc")

    if transport == "inproc":
        base = f"inproc://test-{random_port}"
        endpoint = base
        replay_endpoint = base + "-replay"
    else:
        endpoint = f"tcp://*:{random_port}"
        replay_endpoint = f"tcp://*:{random_port + 1}"

    return KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=endpoint,
        replay_endpoint=replay_endpoint,
        buffer_steps=100,
        hwm=1000,
        topic="test",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def publisher(publisher_config):
    """Yield an event publisher built from the config; shut it down on teardown."""
    event_publisher = EventPublisherFactory.create(publisher_config)
    yield event_publisher
    event_publisher.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
def subscriber(publisher_config):
    """Yield a MockSubscriber wired to the publisher's endpoints."""

    def _connectable(addr):
        # The publisher binds on the wildcard "tcp://*"; a client must
        # dial a concrete host, so rewrite it to loopback.
        if addr and addr.startswith("tcp://*"):
            return addr.replace("*", "127.0.0.1")
        return addr

    mock_sub = MockSubscriber(
        _connectable(publisher_config.endpoint),
        _connectable(publisher_config.replay_endpoint),
        publisher_config.topic,
    )
    yield mock_sub
    mock_sub.close()
|
||||
|
||||
|
||||
class MockSubscriber:
    """Helper class to receive and verify published events.

    Connects a ZMQ SUB socket to the publisher endpoint and, optionally,
    a REQ socket to the replay endpoint so tests can request re-delivery
    of missed messages by sequence number.
    """

    def __init__(self,
                 pub_endpoint: str,
                 replay_endpoint: Optional[str] = None,
                 topic: str = "",
                 decode_type=SampleBatch):
        # Shared process-wide ZMQ context (required for inproc:// transport,
        # where pub and sub must live in the same context).
        self.ctx = zmq.Context.instance()

        # Set up subscriber socket
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
        self.sub.connect(pub_endpoint)

        # Set up replay socket if provided
        self.replay = None
        if replay_endpoint:
            self.replay = self.ctx.socket(zmq.REQ)
            self.replay.connect(replay_endpoint)

        self.topic = topic
        self.topic_bytes = topic.encode('utf-8')
        # All (seq, batch) pairs received so far, in arrival order.
        self.received_msgs: list[tuple[int, SampleBatch]] = []
        # Highest sequence number seen; -1 means nothing received yet.
        self.last_seq = -1
        # msgspec decoder for the payload frame; decode_type must match
        # what the publisher encodes (SampleBatch in these tests).
        self.decoder = msgspec.msgpack.Decoder(type=decode_type)

    def receive_one(self,
                    timeout=1000) -> Union[tuple[int, SampleBatch], None]:
        """Receive a single message with timeout (ms); None on timeout."""
        if not self.sub.poll(timeout):
            return None

        # Wire format: [topic, 8-byte big-endian seq, msgpack payload].
        topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
        assert topic_bytes == self.topic_bytes

        seq = int.from_bytes(seq_bytes, "big")
        data = self.decoder.decode(payload)
        self.last_seq = seq
        self.received_msgs.append((seq, data))
        return seq, data

    def request_replay(self, start_seq: int) -> None:
        """Request replay of messages starting from start_seq."""
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        # The replay request is just the starting sequence number,
        # big-endian, 8 bytes — mirroring the seq frame format above.
        self.replay.send(start_seq.to_bytes(8, "big"))

    def receive_replay(self) -> list[tuple[int, SampleBatch]]:
        """Receive replayed messages until the end-of-replay marker."""
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        replayed: list[tuple[int, SampleBatch]] = []
        while True:
            try:
                # 1s poll per message; a quiet socket ends the replay.
                if not self.replay.poll(1000):
                    break

                frames = self.replay.recv_multipart()
                if not frames or not frames[-1]:
                    # End of replay marker (empty final frame).
                    break

                seq_bytes, payload = frames
                seq = int.from_bytes(seq_bytes, "big")
                data = self.decoder.decode(payload)
                replayed.append((seq, data))
            except zmq.ZMQError as _:
                break

        return replayed

    def close(self):
        """Clean up resources (sockets; the shared context is left open)."""
        self.sub.close()
        if self.replay:
            self.replay.close()
|
||||
193
tests/distributed/test_events.py
Normal file
193
tests/distributed/test_events.py
Normal file
@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import threading
|
||||
import time
|
||||
|
||||
import msgspec
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
||||
NullEventPublisher)
|
||||
|
||||
|
||||
class EventSample(
        msgspec.Struct,
        tag=True,  # type: ignore
        array_like=True  # type: ignore
):
    """Test event for publisher testing.

    tag/array_like configure msgspec's compact tagged-array wire encoding
    so the batch round-trips through the msgpack publisher/subscriber.
    """
    # Sequential index within the batch (see create_test_events).
    id: int
    # Payload string; "test-{id}" in these tests.
    value: str
|
||||
|
||||
|
||||
class SampleBatch(EventBatch):
    """Test event batch for publisher testing.

    Inherits the `ts` timestamp field from EventBatch (create_test_events
    populates it with time.time()).
    """
    # The events carried by this batch.
    events: list[EventSample]
|
||||
|
||||
|
||||
def create_test_events(count: int) -> SampleBatch:
    """Build a SampleBatch of `count` sequentially numbered events."""
    samples = []
    for idx in range(count):
        samples.append(EventSample(id=idx, value=f"test-{idx}"))
    return SampleBatch(ts=time.time(), events=samples)
|
||||
|
||||
|
||||
def test_basic_publishing(publisher, subscriber):
    """A single published batch arrives intact with sequence number 0."""

    batch = create_test_events(5)
    publisher.publish(batch)

    msg = subscriber.receive_one(timeout=1000)
    assert msg is not None, "No message received"

    seq, received = msg
    assert seq == 0, "Sequence number mismatch"
    # Timestamps travel through msgpack, so compare approximately.
    assert received.ts == pytest.approx(batch.ts, abs=0.1), ("Timestamp mismatch")
    assert len(received.events) == len(batch.events), ("Number of events mismatch")

    for idx, evt in enumerate(received.events):
        assert evt.id == idx, "Event id mismatch"
        assert evt.value == f"test-{idx}", "Event value mismatch"
|
||||
|
||||
|
||||
def test_multiple_events(publisher, subscriber):
    """Ten published batches arrive in order with sequence numbers 0..9."""
    num_msgs = 10
    for _ in range(num_msgs):
        publisher.publish(create_test_events(2))

    collected = []
    for _ in range(num_msgs):
        msg = subscriber.receive_one(timeout=100)
        if msg:
            collected.append(msg)

    assert len(collected) == 10, "Number of messages mismatch"
    sequence_numbers = [seq for seq, _ in collected]
    assert sequence_numbers == list(range(10)), "Sequence numbers mismatch"
|
||||
|
||||
|
||||
def test_replay_mechanism(publisher, subscriber):
    """Test the replay mechanism works correctly.

    Publishes 19 messages, asks for a replay from seq 10, then publishes
    a 20th; the replayed stream must be a consecutive run starting >= 10.
    """
    for _ in range(19):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(10)

    batch = create_test_events(1)
    publisher.publish(batch)  # 20th message

    replayed = subscriber.receive_replay()

    assert len(replayed) > 0, "No replayed messages received"
    seqs = [seq for seq, _ in replayed]
    assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
    # Consecutive run check: seqs must equal the full range min..max.
    assert seqs == list(range(min(seqs),
                              max(seqs) +
                              1)), ("Replayed messages not consecutive")
|
||||
|
||||
|
||||
def test_buffer_limit(publisher, subscriber, publisher_config):
    """Test buffer limit behavior.

    Publishes more batches than the replay buffer holds, then requests a
    replay from seq 0: the publisher can only replay what is still
    buffered, so the oldest replayed sequence must have advanced.
    """
    buffer_size = publisher_config.buffer_steps

    # Publish more events than the buffer can hold
    for _ in range(buffer_size + 10):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(0)

    batch = create_test_events(1)
    publisher.publish(batch)

    replayed = subscriber.receive_replay()

    # Guard first: min() below raises ValueError on an empty list, which
    # would obscure the real failure (no replayed messages at all).
    assert replayed, "No replayed messages received"
    assert len(replayed) <= buffer_size, "Can't replay more than buffer size"

    oldest_seq = min(seq for seq, _ in replayed)
    assert oldest_seq >= 10, "The oldest sequence should be at least 10"
|
||||
|
||||
|
||||
def test_topic_filtering(publisher_config):
    """
    Test that a subscriber only receives messages matching its topic filter
    """
    # Replay is irrelevant here; disable it before copying the config.
    publisher_config.replay_endpoint = None

    cfg = publisher_config.model_copy()
    cfg.topic = "foo"
    pub = EventPublisherFactory.create(cfg)

    # Imported lazily: conftest imports SampleBatch from this module, so a
    # top-level import here would be circular.
    from .conftest import MockSubscriber
    sub_foo = MockSubscriber(cfg.endpoint, None, "foo")
    sub_bar = MockSubscriber(cfg.endpoint, None, "bar")

    try:
        # Give the subscriber sockets a moment to finish connecting
        # before the first publish.
        time.sleep(0.1)

        for _ in range(3):
            pub.publish(create_test_events(1))

        foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is not None for msg in foo_received), (
            "Subscriber with matching topic should receive messages")

        bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is None for msg in bar_received), (
            "Subscriber with non-matching topic should receive no messages")
    finally:
        pub.shutdown()
        sub_foo.close()
        sub_bar.close()
|
||||
|
||||
|
||||
def test_high_volume(publisher, subscriber):
    """Test publishing and receiving a high volume of events.

    Publishes 10k batches from a background thread while the main thread
    drains the subscriber; requires ordered delivery of at least 90%.
    """
    num_batches = 10_000
    events_per_batch = 100

    # Publish events in a separate thread to not block
    def publish_events():
        for i in range(num_batches):
            batch = create_test_events(events_per_batch)
            publisher.publish(batch)
            # Small delay to avoid overwhelming
            if i % 100 == 0:
                time.sleep(0.01)

    received: list[tuple[int, SampleBatch]] = []

    publisher_thread = threading.Thread(target=publish_events)
    publisher_thread.start()

    start_time = time.time()
    while len(received) < num_batches:
        if time.time() - start_time > 10:  # Timeout after 10 seconds
            break

        result = subscriber.receive_one(timeout=100)
        if result:
            received.append(result)

    publisher_thread.join()

    # Some loss is tolerated under load; only 90% delivery is required.
    assert len(received) >= num_batches * 0.9, (
        "We should have received most messages")

    seqs = [seq for seq, _ in received]
    assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
||||
|
||||
|
||||
def test_null_publisher():
    """Test that NullEventPublisher can be used without errors"""
    null_pub = NullEventPublisher()

    # This should not raise any errors
    null_pub.publish(create_test_events(5))
    null_pub.shutdown()
|
||||
@ -6,6 +6,7 @@ from typing import Optional
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
@ -48,9 +49,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
num_blocks=num_blocks,
|
||||
tensors={},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32,
|
||||
False))
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@ -783,6 +785,60 @@ def test_prefix_cache_stats_disabled():
|
||||
assert manager.prefix_cache_stats is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events(blocks_to_cache: int):
    """KVCacheManager emits BlockStored/BlockRemoved/AllBlocksCleared events
    through its block pool, and take_events() drains the queue."""
    block_size = 16
    # One spare block beyond what the request needs.
    num_blocks = blocks_to_cache + 1

    # Allocate Blocks
    # Should see a single block stored event with a blocks_to_cache number of
    # block hashes
    # take_events should reset the kv_event_queue
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
    )

    num_tokens = block_size * blocks_to_cache
    req0 = make_request("0", list(range(num_tokens)))
    _ = manager.allocate_slots(req0, num_tokens)
    events = manager.take_events()

    # The last event is the stored-block event for this allocation.
    block = events[-1]
    assert (len(block.block_hashes) == blocks_to_cache == len(
        manager.block_pool.cached_block_hash_to_block))
    assert len(block.token_ids) == block.block_size * len(block.block_hashes)
    # take_events must leave the queue empty.
    assert len(manager.block_pool.kv_event_queue) == 0

    stored_block_hash = block.block_hashes

    # Remove blocks and send another request
    # Should see block_to_cache number of removed block events and a new block
    # stored event
    manager.free(req0)
    req1 = make_request("1", list(range(num_tokens)))
    _ = manager.allocate_slots(req1, num_tokens)
    events = manager.take_events()

    # Every event before the final stored event removes a previously
    # stored block hash.
    for blocks in events[:-1]:
        assert blocks.block_hashes[0] in stored_block_hash
    assert len(events) == blocks_to_cache + 1
    assert (isinstance(events[-2], BlockRemoved))
    assert (len(events[-1].block_hashes) == blocks_to_cache == len(
        manager.block_pool.cached_block_hash_to_block))

    # All Blocks Cleared
    # Should see a single all blocks cleared event
    manager.free(req1)
    manager.reset_prefix_cache()
    events = manager.take_events()

    assert isinstance(events[-1], AllBlocksCleared)
    assert len(manager.block_pool.cached_block_hash_to_block) == 0
|
||||
|
||||
|
||||
def test_eagle_enabled_removes_last_block():
|
||||
"""Verify Eagle does NOT remove blocks when request
|
||||
length is divisible by block size."""
|
||||
|
||||
@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
|
||||
from ...distributed.conftest import publisher_config, random_port # noqa: F401
|
||||
|
||||
from tests.v1.engine.utils import FULL_STRINGS # isort: skip
|
||||
|
||||
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
@ -11,6 +11,7 @@ import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@ -20,6 +21,7 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
|
||||
SyncMPClient)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
from ...distributed.conftest import MockSubscriber
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
@ -199,54 +201,142 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
log_stats=True,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
"""Normal Request Cycle."""
|
||||
try:
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
"""Normal Request Cycle."""
|
||||
|
||||
requests = [make_request(params) for _ in range(10)]
|
||||
request_ids = [req.request_id for req in requests]
|
||||
requests = [make_request(params) for _ in range(10)]
|
||||
request_ids = [req.request_id for req in requests]
|
||||
|
||||
# Add requests to the engine.
|
||||
for request in requests:
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
# Add requests to the engine.
|
||||
for request in requests:
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
# Add requests to the engine.
|
||||
for idx, request in enumerate(requests):
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
if idx % 2 == 0:
|
||||
await client.abort_requests_async([request.request_id])
|
||||
|
||||
outputs = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
else:
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
"""Utility method invocation"""
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
core_client: AsyncMPClient = client
|
||||
# Add requests to the engine.
|
||||
for idx, request in enumerate(requests):
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
if idx % 2 == 0:
|
||||
await client.abort_requests_async([request.request_id])
|
||||
|
||||
result = await core_client.call_utility_async("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
outputs = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
await core_client.call_utility_async("echo", None, "help!")
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
else:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
"""Utility method invocation"""
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
core_client: AsyncMPClient = client
|
||||
|
||||
result = await core_client.call_utility_async("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
await core_client.call_utility_async("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
finally:
|
||||
client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "multiprocessing_mode,publisher_config",
    [(True, "tcp"), (False, "inproc")],
    indirect=["publisher_config"],
)
def test_kv_cache_events(
    monkeypatch: pytest.MonkeyPatch,
    multiprocessing_mode: bool,
    publisher_config,
):
    """End-to-end: an engine core run emits a BlockStored KV event that a
    ZMQ subscriber can decode as a KVEventBatch."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        block_size = 16
        num_blocks = 2

        engine_args = EngineArgs(model=MODEL_NAME,
                                 enforce_eager=True,
                                 enable_prefix_caching=True,
                                 block_size=block_size)
        # Wire the parametrized KV-events config into the engine.
        engine_args.kv_events_config = publisher_config

        vllm_config = engine_args.create_engine_config(
            UsageContext.UNKNOWN_CONTEXT)

        executor_class = Executor.get_class(vllm_config)
        client = EngineCoreClient.make_client(
            multiprocess_mode=multiprocessing_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
        )
        # The publisher binds on "*"; the subscriber must dial loopback.
        endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
        time.sleep(0.1)
        # NOTE(review): subscriber is never closed on exit — consider
        # adding subscriber.close() to the finally block.
        subscriber = MockSubscriber(endpoint,
                                    topic=publisher_config.topic,
                                    decode_type=KVEventBatch)

        try:
            # Exactly num_blocks full blocks of prompt tokens.
            custom_tokens = list(range(num_blocks * block_size))
            request = EngineCoreRequest(
                request_id=str(uuid.uuid4()),
                prompt_token_ids=custom_tokens,
                mm_inputs=None,
                mm_hashes=None,
                mm_placeholders=None,
                sampling_params=SamplingParams(
                    max_tokens=1),  # Short completion for speed
                eos_token_id=None,
                arrival_time=time.time(),
                lora_request=None,
            )
            client.add_request(request)

            outputs: dict[str, list] = {request.request_id: []}
            loop_until_done(client, outputs)

            result = subscriber.receive_one(timeout=1000)
            assert result is not None, "No message received"

            seq, received = result

            assert seq == 0, "Sequence number mismatch"
            assert len(received.events) == 1, (
                "We should have exactly one BlockStored event")
            event = received.events[0]
            assert isinstance(
                event, BlockStored), ("We should have a BlockStored event")
            assert len(event.block_hashes) == num_blocks, (
                "We should have a BlockStored event with 2 block_hashes")
            assert event.block_size == block_size, (
                "Block size should be the same as the block size")
            assert event.parent_block_hash is None, (
                "Parent block hash should be None")
            assert event.lora_id is None, "Lora id should be None"
            assert len(event.token_ids) == num_blocks * block_size, (
                "Token ids should be the same as the custom tokens")
            assert event.token_ids == custom_tokens, (
                "Token ids should be the same as the custom tokens")
        finally:
            client.shutdown()
        # NOTE(review): this trailing return is redundant at function end.
        return
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
|
||||
Reference in New Issue
Block a user