[Spec Decode] (1/2) Remove batch expansion (#8839)

This commit is contained in:
Lily Liu
2024-10-01 16:04:42 -07:00
committed by GitHub
parent 22f5851b80
commit 1570203864
29 changed files with 531 additions and 99 deletions

View File

@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
max_output_len=32,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(

View File

@ -0,0 +1,65 @@
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.worker import Worker
from .utils import create_batch, create_worker
def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
device: str) -> SpeculativeProposals:
proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
device=device)
proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
return SpeculativeProposals(proposal_token_ids, proposal_probs,
proposal_lens)
def assert_score_equal(score1: SpeculativeScores,
score2: SpeculativeScores) -> None:
assert torch.allclose(score1.probs, score2.probs)
assert torch.allclose(score1.logprobs, score2.logprobs)
assert torch.equal(score1.token_ids, score2.token_ids)
@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scoroer(model_name: str, batch_size: int, propose_len: int,
device: str) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score
"""
seed = 0
block_size = 32
num_gpu_blocks = 2048 // block_size
scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed)
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\
should_modify_greedy_probs_inplace = True
vocab_size = scorer_worker.vocab_size
proposals = create_proposal(batch_size, propose_len, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=propose_len)
batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
vocab_size)
batch_expansion_score = batch_expansion_scorer.score_proposals(
requests, proposals)
mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
mqa_score = mqa_scorer.score_proposals(requests, proposals)
assert_score_equal(batch_expansion_score, mqa_score)

View File

@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int,
acceptance_sampler_method: str):
def test_batch_expansion_correctly_calls_target_model(
k: int, batch_size: int, acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
inputs with batch expansion. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
metrics_collector=metrics_collector,
disable_mqa_scorer=True)
worker.init_device()
vocab_size = 32_000

View File

@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
for i, final_len in enumerate(final_prompt_lens)
}
return [
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data={
i: SequenceData.from_seqs(prompt_token_ids[:],
cont_token_ids[:]),
},
sampling_params=SamplingParams(temperature=0.0, ),
block_tables={i: block_allocations[i][:]},
) for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations))
]
seq_grou_metadata_list = []
for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations)):
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
data.update_num_computed_tokens(
len(prompt_token_ids) + len(cont_token_ids) - 1)
seq_data = {i: data}
seq_grou_metadata_list.append(
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations[i][:]},
))
return seq_grou_metadata_list
def assert_logprobs_dict_allclose(