[V1] Implement sliding window attention in kv_cache_manager (#14097)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-04-01 15:33:17 +08:00
Committed by: GitHub
parent c7e63aa4d8
commit 3a5f0afcd2
15 changed files with 662 additions and 158 deletions

View File

@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
check_answers(indices, answer, test_texts)
def prep_prompts(batch_size: int):
def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
"""
Generate prompts with a bunch of assignments,
then ask for the value of one of them.
The prompt is just under 10k tokens; the sliding window is 4k,
so the answer is outside the sliding window but should still be correct.
Args:
batch_size: number of prompts to generate
ln_range: (min, max) bounds for the randomly sampled prompt length
"""
prompts: list[str] = []
answer: list[int] = []
@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(800, 1100)
ln = random.randint(*ln_range)
for k in range(30, ln):
v = random.randint(10, 99)
if k == idx:
@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
return prompts, answer, indices
def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
def check_answers(indices: list[int],
answer: list[int],
outputs: list[str],
accept_rate: float = 0.7):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7
assert frac_ok >= accept_rate
def check_window(prompts: list[str]):
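For reference, a minimal sketch of how these two helpers fit together in a sliding-window retrieval test (this mirrors the V1 end-to-end test added later in this commit; the model name and sampling parameters here are illustrative):

from vllm import LLM, SamplingParams

# prep_prompts returns (prompts, answer, indices); check_answers asserts the
# fraction of correct retrievals is at least accept_rate.
prompts, answer, indices = prep_prompts(batch_size=5, ln_range=(400, 800))
llm = LLM(model="google/gemma-2-2b-it")  # illustrative sliding-window model
responses = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=100))
check_answers(indices, answer,
              [r.outputs[0].text for r in responses],
              accept_rate=0.7)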

View File

@ -4,6 +4,7 @@
from typing import Optional
import pytest
import torch
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
@ -12,6 +13,8 @@ from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
def make_request(request_id,
@ -39,13 +42,23 @@ def make_request(request_id,
)
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
return KVCacheConfig(
num_blocks=num_blocks,
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
def test_prefill(hash_algo):
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
caching_hash_algo=hash_algo,
num_preallocate_tokens=16,
@ -67,12 +80,12 @@ def test_prefill(hash_algo):
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
for block_id in (1, 2, 3):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
@ -80,7 +93,7 @@ def test_prefill(hash_algo):
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
for block_id in (4, 5):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@ -90,11 +103,11 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
assert [b.block_id for b in blocks] == [6, 7]
for block in computed_blocks:
assert block.ref_cnt == 2
@ -107,14 +120,14 @@ def test_prefill(hash_algo):
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
@ -122,11 +135,11 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]
assert [b.block_id for b in blocks] == [8, 9]
# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@ -148,7 +161,7 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@ -162,10 +175,8 @@ def test_prefill_plp():
3. Schedule plp request; no hit should occur; validate blocks
'''
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -186,13 +197,13 @@ def test_prefill_plp():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
req0_block_hashes = [b.block_hash for b in blocks]
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
for block_id in (1, 2, 3):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
@ -200,7 +211,7 @@ def test_prefill_plp():
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
for block_id in (4, 5):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@ -211,11 +222,11 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
assert [b.block_id for b in blocks] == [6, 7]
for block in computed_blocks:
assert block.ref_cnt == 2
@ -228,14 +239,14 @@ def test_prefill_plp():
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
@ -251,7 +262,7 @@ def test_prefill_plp():
block_ids = [b.block_id for b in blocks]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [0, 1, 2, 3, 4]
assert block_ids != [1, 2, 3, 4, 5]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
@ -263,10 +274,8 @@ def test_prefill_plp():
def test_decode():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -282,7 +291,7 @@ def test_decode():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@ -316,10 +325,8 @@ def test_decode():
def test_evict():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -350,15 +357,15 @@ def test_evict():
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
assert [b.block_id for b in computed_blocks] == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert [b.block_id for b in blocks] == [7, 6]
assert manager.block_pool.free_block_queue.num_free_blocks == 6
@ -369,10 +376,8 @@ def test_hash_block_correct_reuse():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=1,
make_kv_cache_config(16, 2),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -408,10 +413,8 @@ def test_computed_blocks_not_evicted():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=2,
make_kv_cache_config(block_size, 3),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -424,7 +427,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 0
assert blocks[0].block_id == 1
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
@ -433,7 +436,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
assert blocks[0].block_id == 2
# Free the blocks.
manager.free(req0)
@ -444,13 +447,13 @@ def test_computed_blocks_not_evicted():
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0
assert computed_blocks[0].block_id == 1
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
assert blocks[0].block_id == 2
def test_basic_prefix_caching_disabled():
@ -459,10 +462,8 @@ def test_basic_prefix_caching_disabled():
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=4,
make_kv_cache_config(block_size, 5),
max_model_len=8192,
sliding_window=None,
enable_caching=False,
num_preallocate_tokens=0,
)
@ -502,10 +503,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
This tests that the preallocated blocks are correctly added.
"""
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
make_kv_cache_config(block_size, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=num_preallocate_tokens,
)
@ -586,10 +585,8 @@ def test_mm_prefix_caching():
This tests that the multi-modal prefix caching is correct.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -629,7 +626,7 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -667,10 +664,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
make_kv_cache_config(block_size, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -723,10 +718,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
def test_reset_prefix_cache():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -736,7 +729,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@ -745,7 +738,7 @@ def test_reset_prefix_cache():
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks] == [4]
assert [b.block_id for b in blocks] == [5]
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()

View File

@ -2,12 +2,15 @@
from typing import Optional
import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
@ -66,12 +69,21 @@ def create_scheduler(
model_config=model_config,
cache_config=cache_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(16, 1, 1, torch.float32, False))
],
)
cache_config.num_gpu_blocks = 10000
return Scheduler(
scheduler_config,
model_config,
cache_config,
lora_config=None,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)

View File

@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.core.specialized_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec
def test_sliding_window_possible_cached_prefix():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length):
block_hash_list = [
BlockHashType(i, ()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[block_hash] = {
i: block_pool.blocks[i + 10]
}
computed_blocks = manager.find_longest_cache_hit(block_hash_list)
assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block
for block in computed_blocks[:expect_length - 2])
for i in range(2):
if i < expect_length:
block_index = expect_length - i - 1
assert computed_blocks[
block_index].block_id == block_index + 10
run_one_case([False] * 10, 0)
run_one_case([True], 1)
run_one_case([True, False], 1)
run_one_case([True, True], 2)
run_one_case([True, True, False], 2)
run_one_case([True, True, True], 3)
run_one_case([True, True, True, False], 3)
run_one_case([
True, True, False, True, False, False, True, True, False, True, True,
True
], 12)
run_one_case([
True, True, False, True, False, False, True, True, False, False, False
], 8)
run_one_case([
True, True, False, True, False, False, True, True, False, False, False,
True
], 8)
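The cases above fix the expected behavior: with block_size=2 and sliding_window=4, a prefix of L cached blocks counts as a hit only if the blocks holding the last sliding_window - 1 tokens before position L * block_size are actually cached; everything earlier can be served by the null block. A minimal sketch of that rule (an illustration of what the test expects, not the actual SlidingWindowManager code; longest_sliding_window_hit is a hypothetical helper):

def longest_sliding_window_hit(block_is_cached: list[bool],
                               block_size: int = 2,
                               sliding_window: int = 4) -> int:
    # Largest prefix length (in blocks) whose trailing sliding_window - 1
    # tokens are fully covered by cached blocks.
    for length in range(len(block_is_cached), 0, -1):
        first_needed_token = max(length * block_size - (sliding_window - 1), 0)
        first_needed_block = first_needed_token // block_size
        if all(block_is_cached[first_needed_block:length]):
            return length
    return 0

Under this rule the ten cases above evaluate to 0, 1, 1, 2, 2, 3, 3, 12, 8 and 8, matching the expected lengths.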
def test_sliding_window_remove_skipped_blocks():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool)
null_block_id = block_pool.null_block.block_id
def id_to_block_table(ids):
return [
KVCacheBlock(id_)
if id_ != null_block_id else block_pool.null_block for id_ in ids
]
def assert_block_id(block_table, ids):
for block, id_ in zip(block_table, ids):
if id_ == null_block_id:
assert block == block_pool.null_block
else:
assert block.block_id == id_
original_block_ids = [
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
]
block_table = id_to_block_table(original_block_ids)
removed = manager.remove_skipped_blocks(block_table, 0)
assert_block_id(removed, [])
assert_block_id(block_table, original_block_ids)
# 4 tokens are computed. Only token 0 is out of the sliding window. As
# block 1000 also contains token 1 that is in the sliding window, block 1000
# cannot be removed.
removed = manager.remove_skipped_blocks(block_table, 4)
assert_block_id(removed, [])
assert_block_id(block_table, original_block_ids)
# 5 tokens are computed. Token 0 & 1 are out of the sliding window.
# Block 1000 can be removed.
removed = manager.remove_skipped_blocks(block_table, 5)
assert_block_id(removed, [original_block_ids[0]])
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 6 tokens are computed. Token 0-2 are out of the sliding window.
# Cannot remove new block as the block 1001 is still used by token 3.
removed = manager.remove_skipped_blocks(block_table, 6)
assert_block_id(removed, [])
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 7 tokens are computed. Token 0-3 are out of the sliding window.
# Block 1001 can be removed and block 1000 is already removed.
removed = manager.remove_skipped_blocks(block_table, 7)
assert_block_id(removed, [original_block_ids[1]])
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
# 11 tokens are computed. Token 0-7 are out of the sliding window.
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer
# sequence, and is expected to be evicted earlier than 1002, so the order
# of removed blocks should be [1003, 1002].
removed = manager.remove_skipped_blocks(block_table, 11)
assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
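The comments above all follow from one piece of arithmetic: the first token still needed by the sliding window is num_computed_tokens - sliding_window + 1, and any block whose last token lies before it can be replaced by the null block. A minimal sketch of that rule (illustrative only, not the actual SlidingWindowManager implementation; it mutates block_table in place and returns the freed blocks, later blocks first, matching the eviction order checked above):

def remove_skipped_blocks_sketch(block_table, num_computed_tokens, null_block,
                                 block_size=2, sliding_window=4):
    # Index of the first block that still holds a token inside the window.
    first_useful_token = num_computed_tokens - sliding_window + 1
    first_useful_block = max(first_useful_token, 0) // block_size
    removed = []
    # Walk backwards so blocks covering longer prefixes are freed first.
    for i in range(first_useful_block - 1, -1, -1):
        if block_table[i] is not null_block:
            removed.append(block_table[i])
            block_table[i] = null_block
    return removed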

View File

@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import pytest
from vllm import LLM, SamplingParams
from ...core.block.e2e.test_correctness_sliding_window import (check_answers,
prep_prompts)
@dataclass
class TestConfig:
sliding_window: int
ln_range: tuple[int, int]
model_config = {
"bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
"google/gemma-2-2b-it": TestConfig(4096, (400, 800)),
}
@pytest.mark.parametrize(
"model",
[
"bigcode/starcoder2-3b", # sliding window only
"google/gemma-2-2b-it", # sliding window + full attention
])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(monkeypatch, model, batch_size, seed):
"""
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for the value of one of them (which is outside the sliding window).
If we tell it upfront which one we are going to be looking for, then
it answers correctly (mostly).
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
test_config = model_config[model]
llm = LLM(model=model)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
prompts, answer, indices = prep_prompts(batch_size,
ln_range=test_config.ln_range)
check_length(prompts, llm, test_config.sliding_window)
# Fresh generation
responses = llm.generate(prompts, sampling_params)
check_answers(indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0)
# Re-generate with the same prompts to test prefix caching
responses = llm.generate(prompts, sampling_params)
check_answers(indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0)
def check_length(prompts: list[str], llm: LLM, sliding_window: int):
"""
Check if the prompt length is valid, i.e., longer than the sliding window
size and shorter than the model's max length.
Args:
prompts: list of prompts
llm: LLM object
sliding_window: Sliding window size
"""
tokenizer = llm.get_tokenizer()
max_model_len = llm.llm_engine.model_config.max_model_len
assert any(
len(tokenizer.encode(prompt)) > sliding_window
for prompt in prompts), "Prompt is too short for test"
assert all(
len(tokenizer.encode(prompt)) <= max_model_len
for prompt in prompts), "Prompt is too long for test"