[Bugfix][Nixl] Fix Preemption Bug (#18631)

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Robert Shaw
2025-05-23 19:30:16 -04:00
committed by GitHub
parent 4fc1bf813a
commit 2b10ba7491
2 changed files with 97 additions and 15 deletions
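The commit carries no message body; judging from the diff below, the bug appears to be that the scheduler tracked the token count of a completed async remote-KV receive in a per-pass local variable (num_prealloc_computed_tokens). If the request became ready but could not be scheduled in that same pass (e.g. after a preemption left too few free KV blocks), the count was gone on the next pass and the request was treated as a fresh prefill. A minimal sketch of that failure mode, assuming hypothetical names and a made-up token count of 64 (this is not vLLM code):

```python
# Sketch (not vLLM code) of the failure mode the diff below appears to fix:
# a per-pass local variable loses the pre-allocated token count when the
# request cannot be scheduled in the same pass its remote recv finished.

class Request:
    def __init__(self):
        self.status = "WAITING_FOR_REMOTE_KVS"
        self.num_computed_tokens = 0

def schedule_pass_buggy(request, have_free_blocks):
    num_prealloc_computed_tokens = 0          # reset on every scheduling pass
    if request.status == "WAITING_FOR_REMOTE_KVS":
        request.status = "WAITING"            # async remote recv finished
        request.num_computed_tokens = 64      # these tokens already have KV blocks
        num_prealloc_computed_tokens = request.num_computed_tokens
    if not have_free_blocks:
        return None                           # cannot schedule this pass
    return num_prealloc_computed_tokens

req = Request()
assert schedule_pass_buggy(req, have_free_blocks=False) is None  # pool full
assert schedule_pass_buggy(req, have_free_blocks=True) == 0      # stale!
# The request now looks like a fresh prefill even though 64 tokens were
# already received. The fix branches on request.num_computed_tokens, which
# persists on the request object across passes (see the scheduler diff).
assert req.num_computed_tokens == 64
```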


@@ -340,3 +340,84 @@ def test_full_block_prompt():
     output = outputs[0]
     assert output.finish_reason == FinishReason.STOP
     assert_scheduler_empty(scheduler)
+
+
+def test_cannot_schedule_after_recv():
+    """
+    Test that we can handle no schedule after recv due to not
+    enough remaining KV blocks.
+    """
+
+    # NOTE: the KVCacheManager will use 1 null block.
+    # So there are 5 total working blocks.
+    TOTAL_NUM_BLOCKS = 6
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
+
+    # Prime the KVCache.
+    NUM_PROMPT_BLOCKS = 2
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    # Prompt will use 2 blocks + 1 block after we schedule.
+    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
+    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
+
+    request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
+    request_remote = create_request(request_id=2,
+                                    num_tokens=NUM_TOKENS_REMOTE,
+                                    do_remote_prefill=True)
+
+    # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
+    scheduler.add_request(request_normal)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 0
+
+    # Step 2: 5 blocks are in use (2 new for remote blocks).
+    scheduler.add_request(request_remote)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+
+    # Step 3: finish recving (5 blocks in use)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(
+        reqs=[request_normal], finished_recving=[request_remote.request_id])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+
+    # Step 4: try to schedule, not enough blocks.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+
+    # Step 5: finish the request, free it.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal],
+                                                     use_eos=True)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
+
+    # Step 6: now we can schedule (with 2 blocks computed).
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_remote])
+    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
+            NUM_PROMPT_BLOCKS * BLOCK_SIZE)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 0
+
+    # Step 7: free everything.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_remote],
+                                                     use_eos=True)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    _ = scheduler.schedule()
+    assert_scheduler_empty(scheduler)
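The test's block accounting, worked out as a quick sketch; the concrete BLOCK_SIZE below is hypothetical (the test reads it from vllm_config.cache_config.block_size), and the per-request block counts come from the test's own step comments:

```python
# Block accounting for the test above (BLOCK_SIZE = 16 is a made-up value).
BLOCK_SIZE = 16
usable_blocks = 6 - 1        # 1 of the 6 blocks is the KVCacheManager's null block
normal_blocks = 2 + 1        # 2 prompt blocks + 1 block once decoding starts
remote_recv_blocks = 2       # blocks written by the async remote recv (Step 2)
assert normal_blocks + remote_recv_blocks == usable_blocks  # pool exhausted
# request_remote needs at least one more block to be scheduled (Steps 3-4),
# so it waits until request_normal finishes (Step 5) and frees its 3 blocks.
```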


@@ -310,15 +310,16 @@ class Scheduler(SchedulerInterface):
                 break

             request = self.waiting[0]
-            num_prealloc_computed_tokens = 0

-            # P/D: skip request if still waiting for remote kvs.
+            # KVTransfer: skip request if still waiting for remote kvs.
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                 is_ready = self._update_waiting_for_remote_kv(request)
                 if is_ready:
                     request.status = RequestStatus.WAITING
-                    num_prealloc_computed_tokens = (
-                        request.num_computed_tokens)
                 else:
+                    logger.debug(
+                        "%s is still in WAITING_FOR_REMOTE_KVS state.",
+                        request.request_id)
                     self.waiting.popleft()
                     skipped_waiting_requests.appendleft(request)
                     continue
@@ -349,8 +350,9 @@ class Scheduler(SchedulerInterface):

             load_kv_async = False

             # Get already-cached tokens.
-            if num_prealloc_computed_tokens == 0:
-                new_computed_blocks, num_native_computed_tokens = \
+            if request.num_computed_tokens == 0:
+                # Get locally-cached tokens.
+                new_computed_blocks, num_new_local_computed_tokens = \
                     self.kv_cache_manager.get_computed_blocks(
                         request)
@@ -358,23 +360,22 @@ class Scheduler(SchedulerInterface):
                 if self.connector is not None:
                     num_external_computed_tokens, load_kv_async = (
                         self.connector.get_num_new_matched_tokens(
-                            request, num_native_computed_tokens))
+                            request, num_new_local_computed_tokens))

                 # Total computed tokens (local + external).
-                num_computed_tokens = (num_native_computed_tokens +
+                num_computed_tokens = (num_new_local_computed_tokens +
                                        num_external_computed_tokens)
+            # KVTransfer: WAITING reqs have num_computed_tokens > 0
+            # after async KV recvs are completed.
             else:
-                # P/D: skip checking prefix cache if loaded from remote kvs.
                 new_computed_blocks = KVCacheBlocks.create_empty()
-                num_native_computed_tokens = 0
-
-                # Total computed tokens (allocated in prior step).
-                num_computed_tokens = num_prealloc_computed_tokens
+                num_new_local_computed_tokens = 0
+                num_computed_tokens = request.num_computed_tokens

             encoder_inputs_to_schedule = None
             new_encoder_budget = encoder_budget

-            # P/D: loading remote KV, do not allocate for new work.
+            # KVTransfer: loading remote KV, do not allocate for new work.
             if load_kv_async:
                 assert num_external_computed_tokens > 0
                 num_new_tokens = 0
@@ -405,7 +406,7 @@ class Scheduler(SchedulerInterface):
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
                     num_new_tokens + num_external_computed_tokens,
-                    num_native_computed_tokens,
+                    num_new_local_computed_tokens,
                     new_computed_blocks,
                     num_lookahead_tokens=self.num_lookahead_tokens,
                     delay_cache_blocks=load_kv_async,
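A condensed paraphrase of the branch after this commit, for readability; this is not a verbatim excerpt, and the two helper parameters are stand-ins for the real prefix-cache and connector calls shown in the diff:

```python
# Condensed paraphrase of the fixed scheduler branch (not vLLM code).
def computed_token_counts(request, lookup_prefix_cache, ask_connector):
    if request.num_computed_tokens == 0:
        # Fresh request: check the local prefix cache, then ask the KV
        # connector how many extra tokens a remote prefill can supply.
        num_new_local = lookup_prefix_cache(request)
        num_external = ask_connector(request, num_new_local)
        return num_new_local, num_new_local + num_external
    # The request already received remote KVs in an earlier step (possibly
    # several steps ago, if it could not be scheduled right away): trust the
    # count persisted on the request, not a per-pass local variable.
    return 0, request.num_computed_tokens
```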