[Core] Sliding window for block manager v2 (#4545)
Co-authored-by: Ruth Evans <ruthevans@Ruths-MacBook-Pro.local>
@@ -1,3 +1,5 @@
from typing import Callable, Iterable, Optional

import pytest

from vllm import LLM
@@ -40,3 +42,27 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
    for llm in generator_inner():
        yield llm
        del llm


def get_text_from_llm_generator(llm_generator: Iterable[LLM],
                                prompts,
                                sampling_params,
                                llm_cb: Optional[Callable[[LLM],
                                                          None]] = None):
    for llm in llm_generator:
        if llm_cb:
            llm_cb(llm)
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        text = [output.outputs[0].text for output in outputs]
        del llm

    return text


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids

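A brief usage sketch of these helpers (hypothetical driver code, not part of the diff; the real callers appear in the test files below): each generator yields a single LLM, so construction is lazy and `del llm` frees GPU memory between runs.

# Hypothetical driver, assuming `from vllm import LLM, SamplingParams`.
def tiny_llm_generator():
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # assumed model
    yield llm
    del llm

texts = get_text_from_llm_generator(
    tiny_llm_generator(),
    prompts=["Hello, my name is"],
    sampling_params=SamplingParams(max_tokens=4))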
@@ -4,6 +4,8 @@ import pytest

from vllm import SamplingParams

from .conftest import get_token_ids_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
@@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
    assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids

tests/core/block/e2e/test_correctness_sliding_window.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import random
from typing import List

import pytest

from vllm import LLM, SamplingParams

from .conftest import get_text_from_llm_generator

# relatively small model with a 4k sliding window
MODEL = "bigcode/starcoder2-3b"
BLOCK_SIZE = 16


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": MODEL,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,
        "block_size": BLOCK_SIZE,
        # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008
        "num_gpu_blocks_override": 100000 // BLOCK_SIZE,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
                                  batch_size, seed):
    """
    The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
    asks for the value of one of them (which is outside the sliding window).
    If we tell it upfront which one we are going to be looking for, then
    it answers correctly (mostly).

    Additionally, we compare the results of the v1 and v2 block managers.
    """
    sampling_params = SamplingParams(
        max_tokens=1024,
        ignore_eos=True,
        temperature=0.0,
    )

    prompts, answer, indices = prep_prompts(batch_size)

    print('Getting text from block manager v1')
    baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
                                                 prompts,
                                                 sampling_params,
                                                 llm_cb=check_window(prompts))

    check_answers(indices, answer, baseline_texts)

    print('Getting text from block manager v2')
    test_texts = get_text_from_llm_generator(test_llm_generator, prompts,
                                             sampling_params)
    check_answers(indices, answer, test_texts)

    cmp = [
        expected_text == actual_text
        for expected_text, actual_text in zip(baseline_texts, test_texts)
    ]
    print(cmp)
    # Make sure the outputs mostly match; mismatches are possibly due to
    # https://github.com/vllm-project/vllm/pull/4768. However,
    # https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290
    # states that xformers and flash_attn have different ideas about the
    # window size anyway.
    assert sum(cmp) > 0.7 * len(cmp)


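For orientation, here is a sketch of how the parametrized kwarg layers above combine; this is my reading of the conftest fixtures (hypothetical names, not code from this diff): the baseline and test engines share every setting except the block-manager flag.

# Hypothetical illustration of the fixture composition (not part of the diff).
common = {
    "model": MODEL,
    "enforce_eager": True,
    "block_size": BLOCK_SIZE,
    "num_gpu_blocks_override": 100000 // BLOCK_SIZE,
}
baseline_kwargs = {**common, "use_v2_block_manager": False}  # block manager v1
test_kwargs = {**common, "use_v2_block_manager": True}       # block manager v2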
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": MODEL,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,
        "block_size": BLOCK_SIZE,
        "num_gpu_blocks_override": 100000 // BLOCK_SIZE,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "use_v2_block_manager": True,
    "enable_chunked_prefill": True
}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
    """
    This is similar to test_sliding_window_retrieval; however, it doesn't
    compare against the v1 block manager, since v1 doesn't support
    chunked prefill with sliding window.

    The results with and without chunked prefill are not the same due to
    numerical instabilities.
    """
    sampling_params = SamplingParams(
        max_tokens=10,
        ignore_eos=True,
        temperature=0.0,
    )

    prompts, answer, indices = prep_prompts(batch_size)

    # We don't compare with the baseline model here, since the results
    # are slightly different due to different handling of the attention tail.
    test_texts = get_text_from_llm_generator(test_llm_generator,
                                             prompts,
                                             sampling_params,
                                             llm_cb=check_window(prompts))
    check_answers(indices, answer, test_texts)


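As a minimal standalone sketch of the configuration under test, assuming vllm.LLM forwards these engine arguments as the fixtures imply (not code from this diff):

from vllm import LLM, SamplingParams

# Assumed engine arguments, mirroring the parametrization above.
llm = LLM(model="bigcode/starcoder2-3b",
          enforce_eager=True,
          block_size=16,
          num_gpu_blocks_override=100000 // 16,
          use_v2_block_manager=True,
          enable_chunked_prefill=True)
out = llm.generate(["assert x42 == "],
                   SamplingParams(max_tokens=10, temperature=0.0))
print(out[0].outputs[0].text)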
def prep_prompts(batch_size: int):
    """
    Generate prompts with a bunch of assignments,
    then ask for the value of one of them.
    The prompt is just under 10k tokens; the sliding window is 4k,
    so the answer is outside the sliding window but should still be correct.
    """
    prompts: List[str] = []
    answer: List[int] = []
    indices: List[int] = []
    random.seed(1)
    for _ in range(batch_size):
        idx = random.randint(30, 90)
        indices.append(idx)
        prompt = "```python\n# We set a number of variables, " + \
            f"x{idx} will be important later\n"
        ln = random.randint(800, 1100)
        for k in range(30, ln):
            v = random.randint(10, 99)
            if k == idx:
                answer.append(v)
            prompt += f"x{k} = {v}\n"
        prompt += f"# Now, we check the value of x{idx}:\n"
        prompt += f"assert x{idx} == "
        prompts.append(prompt)
    return prompts, answer, indices


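A rough size check on these prompts (my arithmetic, not part of the diff), which backs the docstring's "just under 10k tokens" claim:

# Assumption: each "xNN = MM\n" line is roughly 8-10 tokens for a code tokenizer.
lines = 1100                    # upper bound on ln
tokens_per_line = 9             # assumed average
print(lines * tokens_per_line)  # ~9900 tokens, well past the 4096-token window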
def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
    answer2 = [int(text[0:2].strip()) for text in outputs]
    print(list(zip(indices, zip(answer, answer2))))
    numok = 0
    for a1, a2 in zip(answer, answer2):
        if a1 == a2:
            numok += 1
    frac_ok = numok / len(answer)
    print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
    assert frac_ok > 0.7


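A note on the parsing above (observation, not part of the diff): every expected value comes from random.randint(10, 99) in prep_prompts, so the greedy completion should begin with exactly two digits and int(text[0:2].strip()) is sufficient.

# e.g. a completion of "17\n# done" parses to the integer 17:
assert int("17\n# done"[0:2].strip()) == 17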
def check_window(prompts: List[str]):

    def inner(llm: LLM):
        sliding_window = llm.llm_engine.model_config.get_sliding_window()
        assert sliding_window and sliding_window > 0
        assert any(
            len(llm.get_tokenizer().tokenize(prompt)) > sliding_window
            for prompt in prompts)

    return inner
@@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
            range(prompt_len + num_slots_to_append + num_lookahead_slots)),
        block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
    assert num_consumed_blocks == expected_consumed_blocks


@pytest.mark.parametrize("block_size", [8, 16])
@pytest.mark.parametrize("prompt_len", [10, 300, 1000])
@pytest.mark.parametrize("num_slots_to_append", [50])
@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512])
def test_sliding_window(block_size, prompt_len, num_slots_to_append,
                        sliding_window):
    """Verify that append_slots consumes the correct number of blocks from
    the block table when a sliding window is set.
    """

    num_gpu_blocks = 1024
    watermark = 0.1
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        watermark=watermark,
        sliding_window=sliding_window,
    )

    def check_used(min_n, max_n=None):
        if max_n is None:
            max_n = min_n
        used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
        assert min_n <= used
        assert used <= max_n

    def num_blocks(num_tokens):
        return (num_tokens + block_size - 1) // block_size

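num_blocks is plain ceiling division; a few spot checks (illustration only, not part of the diff):

    # (num_tokens + block_size - 1) // block_size == ceil(num_tokens / block_size)
    # With block_size = 16: num_blocks(1) == 1, num_blocks(16) == 1,
    # num_blocks(17) == 2, num_blocks(300) == 19.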
    check_used(0)

    seq_group = create_seq_group(
        seq_prompt_len=prompt_len,
        seq_output_lens=[0],
    )

    check_used(0)

    # Allocate seq
    assert block_manager.can_allocate(seq_group)
    block_manager.allocate(seq_group)

    check_used(num_blocks(prompt_len))

    # Set seq to RUNNING
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    seq.data.update_num_computed_tokens(prompt_len)
    check_used(num_blocks(prompt_len))

    # this is how we compute it in BlockSpaceManagerV2.__init__
    sliding_blocks = (sliding_window // block_size) + 2
    # plus one block for the null block
    sliding_blocks += 1

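To make the bound concrete (my arithmetic, not part of the diff), at the extremes of the parametrization above:

    # sliding_window=20,  block_size=8:   (20 // 8)   + 2 + 1 = 5 blocks
    # sliding_window=512, block_size=16:  (512 // 16) + 2 + 1 = 35 blocks
    # The checks below then allow at most one extra in-flight block.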
    # Append tokens to the sequence
    for token_id in range(num_slots_to_append):
        seq.append_token_id(token_id, {token_id: Logprob(0.0)})
        seq.data.update_num_computed_tokens(1)
        block_manager.append_slots(seq, num_lookahead_slots=0)
        if prompt_len < sliding_window + 10:
            check_used(0, sliding_blocks + 1)
        else:
            check_used(sliding_blocks, sliding_blocks + 1)