[BugFix] Fix tpu_model_runner block_id concatenation (#19228)
Signed-off-by: Nick Hill <nhill@redhat.com>
@@ -226,7 +226,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[],
+        new_block_ids=[[]],
         num_computed_tokens=0,
     )

@@ -460,8 +460,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Update the block IDs.
             if not req_data.resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                for i in range(len(self.kv_cache_config.kv_cache_groups)):
-                    req_state.block_ids[i].extend(req_data.new_block_ids[i])
+                for block_ids, new_block_ids in zip(  # type: ignore[call-overload]
+                        req_state.block_ids,
+                        req_data.new_block_ids,
+                        strict=True):
+                    block_ids.extend(new_block_ids)
             else:
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
@@ -413,7 +413,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             req_state.num_computed_tokens = req_data.num_computed_tokens
             if not req_data.resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                req_state.block_ids.extend(req_data.new_block_ids)
+                for block_ids, new_block_ids in zip(  # type: ignore[call-overload]
+                        req_state.block_ids,
+                        req_data.new_block_ids,
+                        strict=True):
+                    block_ids.extend(new_block_ids)
             else:
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
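
A minimal standalone sketch of the concatenation bug, assuming block IDs are tracked as one list per KV cache group (as the hunks above suggest); the variable names are illustrative only, not vLLM code:

# Illustrative sketch, not vLLM code.
# Assumes block IDs are kept as one list of block IDs per KV cache group.

block_ids = [[0, 1], [10, 11]]   # existing blocks for group 0 and group 1
new_block_ids = [[2], [12]]      # newly allocated blocks, also per group

# Old TPU-runner pattern: extend() on the outer list appends whole
# per-group lists, producing the wrong nesting.
buggy = [ids.copy() for ids in block_ids]
buggy.extend(new_block_ids)
assert buggy == [[0, 1], [10, 11], [2], [12]]  # wrong shape

# New pattern (as in the hunks above): extend each group's list in place.
fixed = [ids.copy() for ids in block_ids]
for ids, new_ids in zip(fixed, new_block_ids, strict=True):
    ids.extend(new_ids)
assert fixed == [[0, 1, 2], [10, 11, 12]]  # per-group concatenation

zip(..., strict=True) raises ValueError if the two lists ever disagree on the number of groups (strict= requires Python 3.10+).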