[Spec Decode] (1/2) Remove batch expansion (#8839)
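In speculative decoding, the target model scores the tokens proposed by the
draft model. The batch-expansion scorer builds those scores by flattening each
proposal into extra single-continuation rows, so the scoring batch grows
roughly by a factor of the proposal length k; the MQA scorer keeps one row per
sequence and queries all k draft positions in one pass. The sketch below is a
minimal illustration of that batch-size difference only; it is not vLLM's
implementation, and all names in it are hypothetical.

# Minimal sketch (hypothetical names, not vLLM code): how batch expansion
# inflates the scoring batch, versus one multi-query row per sequence.
from typing import List

def batch_expansion_rows(prompts: List[List[int]],
                         proposals: List[List[int]]) -> List[List[int]]:
    """One scoring row per draft token: batch grows by the proposal length."""
    rows = []
    for prompt, proposal in zip(prompts, proposals):
        for k in range(1, len(proposal) + 1):
            # Each row re-scores the prompt plus the first k draft tokens.
            rows.append(prompt + proposal[:k])
    return rows

def mqa_rows(prompts: List[List[int]],
             proposals: List[List[int]]) -> List[List[int]]:
    """One scoring row per sequence: all draft tokens are queried together."""
    return [p + d for p, d in zip(prompts, proposals)]

prompts = [[1, 2, 3], [4, 5]]
proposals = [[7, 8, 9], [6, 7, 8]]  # 3 draft tokens per sequence
assert len(batch_expansion_rows(prompts, proposals)) == 6  # batch_size * k
assert len(mqa_rows(prompts, proposals)) == 2              # batch_size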
@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
                                   max_output_len=32,
                                   seed=seed,
                                   temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model_name": MAIN_MODEL,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
+                    output_len: int, seed: int):
+    """Verify that speculative decoding generates the same output
+    with batch expansion scorer and mqa scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)
@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
                                   temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": MAIN_MODEL,
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_disable_by_batch_size": 4
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
+                    output_len: int, seed: int):
+    """Verify that speculative decoding generates the same output
+    with batch expansion scorer and mqa scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)
 
 
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
                                   max_output_len=output_len,
                                   seed=seed,
                                   temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model_name": MAIN_MODEL,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "speculative_model": SPEC_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
+                    output_len: int, seed: int):
+    """Verify that speculative decoding generates the same output
+    with batch expansion scorer and mqa scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)
@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
                                   max_output_len=output_len,
                                   seed=seed,
                                   temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model_name": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_scorer(vllm_runner, common_llm_kwargs,
+                      per_test_common_llm_kwargs, baseline_llm_kwargs,
+                      test_llm_kwargs, batch_size: int, output_len: int,
+                      seed: int):
+    """Verify that ngram speculative decoding generates the same output
+    with batch expansion scorer and mqa scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)
@@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
     )
 
     worker = create_worker(
tests/spec_decode/test_scorer.py  (new file, 65 lines)
@@ -0,0 +1,65 @@
+import pytest
+import torch
+
+from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
+from vllm.spec_decode.mqa_scorer import MQAScorer
+from vllm.worker.worker import Worker
+
+from .utils import create_batch, create_worker
+
+
+def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
+                    device: str) -> SpeculativeProposals:
+    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
+                                device=device)
+    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
+    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
+    return SpeculativeProposals(proposal_token_ids, proposal_probs,
+                                proposal_lens)
+
+
+def assert_score_equal(score1: SpeculativeScores,
+                       score2: SpeculativeScores) -> None:
+    assert torch.allclose(score1.probs, score2.probs)
+    assert torch.allclose(score1.logprobs, score2.logprobs)
+    assert torch.equal(score1.token_ids, score2.token_ids)
+
+
+@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
+@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
+@pytest.mark.parametrize('propose_len', [1, 3, 5])
+@pytest.mark.parametrize('device', ['cuda'])
+def test_scorer(model_name: str, batch_size: int, propose_len: int,
+                device: str) -> None:
+    """
+    Verify that the batch expansion scorer and MQA scorer return the same score.
+    """
+    seed = 0
+    block_size = 32
+    num_gpu_blocks = 2048 // block_size
+    scorer_worker = create_worker(Worker, model_name, block_size,
+                                  num_gpu_blocks, seed)
+    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
+    scorer_worker.model_runner.model.sampler.\
+        should_modify_greedy_probs_inplace = True
+
+    vocab_size = scorer_worker.vocab_size
+    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
+    seq_group_metadatalist, _, _ = create_batch(batch_size,
+                                                propose_len,
+                                                block_size=block_size,
+                                                num_gpu_blocks=num_gpu_blocks)
+    requests = ExecuteModelRequest(seq_group_metadatalist,
+                                   num_lookahead_slots=propose_len)
+
+    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
+                                                      vocab_size)
+    batch_expansion_score = batch_expansion_scorer.score_proposals(
+        requests, proposals)
+
+    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
+    mqa_score = mqa_scorer.score_proposals(requests, proposals)
+
+    assert_score_equal(batch_expansion_score, mqa_score)
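Assuming vLLM's test dependencies and a CUDA device are available, the new file
should be runnable on its own with "pytest tests/spec_decode/test_scorer.py";
the parametrization above then sweeps batch sizes 1 through 16 and proposal
lengths 1, 3, and 5 against facebook/opt-125m.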
@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
 @pytest.mark.parametrize("acceptance_sampler_method",
                          ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_target_model(k: int, batch_size: int,
-                                      acceptance_sampler_method: str):
+def test_batch_expansion_correctly_calls_target_model(
+        k: int, batch_size: int, acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the target model with correct
-    inputs. Everything else is mocked out.
+    inputs with batch expansion. Everything else is mocked out.
     """
     draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
     target_worker = mock_worker(use_spec=False)
@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
         target_worker,
         mock_spec_decode_sampler(acceptance_sampler_method),
         disable_logprobs=False,
-        metrics_collector=metrics_collector)
+        metrics_collector=metrics_collector,
+        disable_mqa_scorer=True)
     worker.init_device()
 
     vocab_size = 32_000
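The worker test above keeps batch expansion reachable through the new
disable_mqa_scorer flag. At the engine level, the same switch appears as the
speculative_disable_mqa_scorer kwarg used throughout the e2e tests. A minimal
sketch of user-facing usage, assuming the kwargs accepted by the test runner
are likewise accepted by the LLM constructor (the model choices here are
illustrative, copied from the tests):

# Sketch (illustrative values from the tests above): opt out of the MQA
# scorer and fall back to the batch-expansion scorer.
from vllm import LLM

llm = LLM(
    model="JackFram/llama-160m",           # illustrative main model
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,             # required for spec decode
    speculative_disable_mqa_scorer=True,   # force batch expansion
)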
@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
         for i, final_len in enumerate(final_prompt_lens)
     }
 
-    return [
-        SequenceGroupMetadata(
-            request_id=str(i),
-            is_prompt=len(cont_token_ids) == 0,
-            seq_data={
-                i: SequenceData.from_seqs(prompt_token_ids[:],
-                                          cont_token_ids[:]),
-            },
-            sampling_params=SamplingParams(temperature=0.0, ),
-            block_tables={i: block_allocations[i][:]},
-        ) for i, (prompt_token_ids,
-                  cont_token_ids) in enumerate(zip(prompts, continuations))
-    ]
+    seq_group_metadata_list = []
+    for i, (prompt_token_ids,
+            cont_token_ids) in enumerate(zip(prompts, continuations)):
+        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
+        data.update_num_computed_tokens(
+            len(prompt_token_ids) + len(cont_token_ids) - 1)
+        seq_data = {i: data}
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=str(i),
+                is_prompt=len(cont_token_ids) == 0,
+                seq_data=seq_data,
+                sampling_params=SamplingParams(temperature=0.0),
+                block_tables={i: block_allocations[i][:]},
+            ))
+    return seq_group_metadata_list
 
 
 def assert_logprobs_dict_allclose(
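A note on the utils refactor above: the added update_num_computed_tokens call
marks every token except the last as already computed, matching the decode-step
invariant that the KV cache covers the whole sequence minus the token about to
be processed. Toy arithmetic with hypothetical lengths:

# Toy example (hypothetical lengths): 5 prompt tokens + 3 generated tokens
# = 8 total; 7 are treated as computed, leaving one token for the next step.
prompt_token_ids = [11, 12, 13, 14, 15]
cont_token_ids = [21, 22, 23]
num_computed = len(prompt_token_ids) + len(cont_token_ids) - 1
assert num_computed == 7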