[V0 Deprecation] Remove V0 Spec Decode workers (#21152)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@ -159,7 +159,6 @@ steps:
|
||||
- tests/distributed/test_utils
|
||||
- tests/distributed/test_pynccl
|
||||
- tests/distributed/test_events
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
- tests/compile/test_basic_correctness
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
@ -182,7 +181,6 @@ steps:
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
- pytest -v -s distributed/test_events.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- pushd ../examples/offline_inference
|
||||
@ -330,17 +328,6 @@ steps:
|
||||
- pytest -v -s samplers
|
||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||
|
||||
- label: Speculative decoding tests # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/spec_decode
|
||||
- tests/spec_decode
|
||||
- vllm/model_executor/models/eagle.py
|
||||
commands:
|
||||
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py
|
||||
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
|
||||
|
||||
- label: LoRA Test %N # 15min each
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
@ -726,7 +713,6 @@ steps:
|
||||
- pytest -v -s distributed/test_sequence_parallel.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
|
||||
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -43,7 +43,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
|
||||
/tests/v1/structured_output @mgoin @russellb @aarnphm
|
||||
|
||||
3
.github/mergify.yml
vendored
3
.github/mergify.yml
vendored
@ -164,10 +164,7 @@ pull_request_rules:
|
||||
description: Automatically apply speculative-decoding label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files~=^vllm/v1/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
- files~=^tests/v1/spec_decode/
|
||||
- files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py
|
||||
- files~=^vllm/model_executor/models/.*eagle.*\.py
|
||||
|
||||
@ -73,7 +73,6 @@ line-length = 80
|
||||
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/worker/**/*.py" = ["UP006", "UP035"]
|
||||
# Python 3.8 typing - skip utils for ROCm
|
||||
"vllm/utils/__init__.py" = ["UP006", "UP035"]
|
||||
|
||||
@ -6,7 +6,7 @@ import msgspec
|
||||
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
|
||||
from ..spec_decode.utils import create_batch
|
||||
from .utils import create_batch
|
||||
|
||||
|
||||
def test_msgspec_serialization():
|
||||
|
||||
@ -4,15 +4,16 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Any, Optional
|
||||
from itertools import count
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (Logprob, Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata)
|
||||
|
||||
|
||||
@ -262,3 +263,130 @@ class SchedulerProxy:
|
||||
self, ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]:
|
||||
_, _, ret = self.call_history["schedule"][-1]
|
||||
return ret
|
||||
|
||||
|
||||
def create_seq_group_metadata_from_prompts(
|
||||
prompts: list[list[int]],
|
||||
num_gpu_blocks: int,
|
||||
block_size: int,
|
||||
final_prompt_lens: list[int],
|
||||
continuations: Optional[list[list[int]]] = None,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if continuations is None:
|
||||
continuations = [[] for _ in prompts]
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(i for i, _ in enumerate(prompts))
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = {
|
||||
i: [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(final_len, block_size))
|
||||
]
|
||||
for i, final_len in enumerate(final_prompt_lens)
|
||||
}
|
||||
|
||||
seq_grou_metadata_list = []
|
||||
for i, (prompt_token_ids,
|
||||
cont_token_ids) in enumerate(zip(prompts, continuations)):
|
||||
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
|
||||
data.update_num_computed_tokens(
|
||||
len(prompt_token_ids) + len(cont_token_ids) - 1)
|
||||
seq_data = {i: data}
|
||||
seq_grou_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=len(cont_token_ids) == 0,
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations[i][:]},
|
||||
))
|
||||
return seq_grou_metadata_list
|
||||
|
||||
|
||||
def create_chunked_seq_group_metadata_from_prompt(
|
||||
prompt: list[int],
|
||||
num_gpu_blocks: int,
|
||||
chunk_size: int,
|
||||
block_size: int,
|
||||
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if seq_id is None:
|
||||
seq_id = 0
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(len(prompt), block_size))
|
||||
]
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
|
||||
chunk_ids = prompt[idx:idx + chunk_size]
|
||||
data = SequenceData.from_seqs(prompt)
|
||||
data.update_num_computed_tokens(idx)
|
||||
seq_data = {i: data}
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations},
|
||||
token_chunk_size=len(chunk_ids)))
|
||||
return seq_group_metadata_list
|
||||
|
||||
|
||||
def create_batch(batch_size,
|
||||
k,
|
||||
prompt_len: Union[int, list[int]] = 10,
|
||||
prev_output_token_len: int = 10,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
num_gpu_blocks: Optional[int] = None,
|
||||
block_size: Optional[int] = None,
|
||||
prefill_chunk_size: Optional[int] = None):
|
||||
if block_size is None:
|
||||
block_size = 8
|
||||
|
||||
if num_gpu_blocks is None:
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
iterator = count()
|
||||
|
||||
if isinstance(prompt_len, int):
|
||||
prompt_lens = [prompt_len for _ in range(batch_size)]
|
||||
else:
|
||||
prompt_lens = prompt_len
|
||||
|
||||
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
|
||||
|
||||
if prefill_chunk_size:
|
||||
# Create a batch of chunked prompts.
|
||||
if not seq_ids:
|
||||
seq_ids = list(range(len(prompts)))
|
||||
seq_group_metadata_list = []
|
||||
for p, sid in zip(prompts, seq_ids):
|
||||
seq_group_metadata_list += \
|
||||
create_chunked_seq_group_metadata_from_prompt(
|
||||
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
|
||||
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
|
||||
prev_output_tokens = []
|
||||
else:
|
||||
prev_output_tokens = [[
|
||||
next(iterator) for _ in range(prev_output_token_len)
|
||||
] for _ in range(batch_size)]
|
||||
final_prompt_lens = [
|
||||
len(prompt) + len(prev_output_token) + k + 1
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids)
|
||||
return seq_group_metadata_list, prompts, prev_output_tokens
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
from prometheus_client import REGISTRY
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import EngineArgs, LLMEngine
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.metrics import RayPrometheusStatLogger
|
||||
@ -232,149 +229,6 @@ def test_engine_log_metrics_regression(
|
||||
assert_metrics(model, engine, disable_log_stats, len(example_prompts))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [10])
|
||||
def test_metric_spec_decode(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
k = 5
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
gpu_memory_utilization=0.4,
|
||||
speculative_config={
|
||||
"model": model,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
) as vllm_model:
|
||||
|
||||
# Force log interval to be 0 to catch all metrics.
|
||||
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
|
||||
stat_logger.local_interval = 0
|
||||
|
||||
# Note that the purpose of this test is to verify spec decode
|
||||
# metrics instead of functional correctness, so the expected values
|
||||
# are intended to be loose.
|
||||
metric_name_to_expected_fn = {
|
||||
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
|
||||
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
|
||||
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
|
||||
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
|
||||
"counter_spec_decode_num_emitted_tokens":
|
||||
lambda v: 0 <= v <= k + 1,
|
||||
}
|
||||
|
||||
# Use one request to better inspect the metrics.
|
||||
prompts = example_prompts[:1]
|
||||
|
||||
_ = vllm_model.generate_greedy(prompts, max_tokens)
|
||||
for metric_name, is_expected in metric_name_to_expected_fn.items():
|
||||
metric_val = getattr(
|
||||
stat_logger.metrics,
|
||||
metric_name).labels(**stat_logger.labels)._value.get()
|
||||
assert is_expected(metric_val), (
|
||||
f"the value of metric {metric_name} ({metric_val}) "
|
||||
"does not meet expectation")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [10])
|
||||
@pytest.mark.parametrize("log_interval", [1, 3, 5, 7])
|
||||
def test_metric_spec_decode_interval(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
log_interval: int,
|
||||
) -> None:
|
||||
k = 5
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
gpu_memory_utilization=0.4,
|
||||
speculative_config={
|
||||
"model": model,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
try:
|
||||
|
||||
engine.add_request(
|
||||
"request-id-0",
|
||||
example_prompts[0],
|
||||
SamplingParams(max_tokens=max_tokens),
|
||||
)
|
||||
|
||||
# set log internal
|
||||
stat_logger = engine.stat_loggers['prometheus']
|
||||
stat_logger.local_interval = log_interval
|
||||
|
||||
# prefill
|
||||
engine.step()
|
||||
|
||||
# wait for 5 seconds to ensure that spec decode metrics
|
||||
# get triggered in first decode step
|
||||
time.sleep(5)
|
||||
|
||||
# first decode step should trigger async collection of metrics
|
||||
engine.step()
|
||||
|
||||
# wait one second to allow H2D transfer to finish
|
||||
time.sleep(1)
|
||||
|
||||
# second decode step should now be able to collect the spec
|
||||
# decode stats and the request should also be finished
|
||||
engine.step()
|
||||
|
||||
# must have finisehd now
|
||||
assert not engine.has_unfinished_requests()
|
||||
|
||||
# wait to ensure logging occurs
|
||||
time.sleep(log_interval)
|
||||
|
||||
# force logging
|
||||
engine.step()
|
||||
|
||||
# Note that the purpose of this test is to verify spec decode
|
||||
# metrics instead of functional correctness, so the expected values
|
||||
# are intended to be loose.
|
||||
metric_name_to_expected_fn = {
|
||||
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
|
||||
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
|
||||
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
|
||||
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
|
||||
"counter_spec_decode_num_emitted_tokens":
|
||||
lambda v: 0 <= v <= k + 1,
|
||||
}
|
||||
|
||||
for metric_name, is_expected in metric_name_to_expected_fn.items():
|
||||
metric_val = getattr(
|
||||
stat_logger.metrics,
|
||||
metric_name).labels(**stat_logger.labels)._value.get()
|
||||
assert is_expected(metric_val), (
|
||||
f"the value of metric {metric_name} ({metric_val}) "
|
||||
"does not meet expectation")
|
||||
|
||||
finally:
|
||||
del engine
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool,
|
||||
num_requests: int) -> None:
|
||||
if disable_log_stats:
|
||||
|
||||
@ -457,12 +457,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
|
||||
|
||||
_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
"EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
|
||||
speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501
|
||||
"MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
|
||||
speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501
|
||||
"MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
|
||||
speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501
|
||||
# Temporarily disabled.
|
||||
# TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
|
||||
# "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
|
||||
# speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501
|
||||
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
|
||||
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
|
||||
@ -72,11 +72,15 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
|
||||
("MLPSpeculatorPreTrainedModel", False, False),
|
||||
("DeepseekV2ForCausalLM", True, False),
|
||||
("Qwen2VLForConditionalGeneration", True, True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"model_arch,is_pp,init_cuda",
|
||||
[
|
||||
# TODO(woosuk): Re-enable this once the MLP Speculator is supported
|
||||
# in V1.
|
||||
# ("MLPSpeculatorPreTrainedModel", False, False),
|
||||
("DeepseekV2ForCausalLM", True, False),
|
||||
("Qwen2VLForConditionalGeneration", True, True),
|
||||
])
|
||||
def test_registry_is_pp(model_arch, is_pp, init_cuda):
|
||||
assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp
|
||||
|
||||
|
||||
@ -1,577 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for rejection sampling."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This file tests V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def mock_causal_accepted_tensor(
|
||||
k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate an "accepted" tensor which should yield causally-accepted tokens
|
||||
up to last accepted indices.
|
||||
|
||||
Tokens after last_accepted_indices+1 may also be accepted, although they
|
||||
will not be causally accepted.
|
||||
"""
|
||||
batch_size = last_accepted_indices.shape[0]
|
||||
|
||||
accepted = (torch.arange(k).expand(batch_size, k)
|
||||
<= last_accepted_indices.unsqueeze(-1).broadcast_to(
|
||||
batch_size, k))
|
||||
|
||||
# Sprinkle accepted values after the contiguous initial accepted values.
|
||||
# This replicates the behavior of rejection sampling, which may "accept"
|
||||
# a token that cannot be accepted because of causality.
|
||||
sprinkle_candidates = (torch.arange(k).expand(
|
||||
batch_size,
|
||||
k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) +
|
||||
1)
|
||||
sprinkle = torch.rand(batch_size, k) > 0.5
|
||||
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
|
||||
return accepted
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize(
|
||||
"which_tokens_accepted",
|
||||
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
device: str, use_flashinfer: bool):
|
||||
"""Verify the output has correct format given predetermined accepted matrix.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
batch_size = 10
|
||||
k = 5
|
||||
vocab_size = 3000
|
||||
|
||||
if which_tokens_accepted == "all_tokens_accepted":
|
||||
accepted = mock_causal_accepted_tensor(
|
||||
k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
|
||||
elif which_tokens_accepted == "no_tokens_accepted":
|
||||
accepted = mock_causal_accepted_tensor(
|
||||
k, -torch.ones((batch_size, ), dtype=torch.long))
|
||||
elif which_tokens_accepted == "some_tokens_accepted":
|
||||
last_accepted_indices = torch.randint(low=-1,
|
||||
high=k,
|
||||
size=(batch_size, ))
|
||||
accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
recovered_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
|
||||
accepted,
|
||||
recovered_token_ids,
|
||||
draft_token_ids,
|
||||
bonus_token_ids,
|
||||
)
|
||||
|
||||
expected_bonus_token_ids = bonus_token_ids.clone()
|
||||
|
||||
if which_tokens_accepted == "all_tokens_accepted":
|
||||
# Expect all tokens to be equal to draft tokens.
|
||||
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
|
||||
|
||||
# Expect all bonus tokens to be included.
|
||||
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
|
||||
elif which_tokens_accepted == "no_tokens_accepted":
|
||||
# Expect first token to be equal to recovered tokens.
|
||||
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
|
||||
|
||||
# Expect everything else to be -1.
|
||||
assert torch.equal(output_token_ids[:, 1:],
|
||||
torch.ones_like(output_token_ids[:, 1:]) * -1)
|
||||
elif which_tokens_accepted == "some_tokens_accepted":
|
||||
recovered_plus_bonus = torch.cat(
|
||||
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
|
||||
# Assert first rejected token is a recovered token or bonus token.
|
||||
assert torch.equal(
|
||||
recovered_plus_bonus[torch.arange(0, batch_size),
|
||||
last_accepted_indices + 1],
|
||||
output_token_ids[torch.arange(0, batch_size),
|
||||
last_accepted_indices + 1])
|
||||
|
||||
# Assert every subsequent token is -1.
|
||||
subsequent_mask = torch.arange(0, k + 1).expand(
|
||||
batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
|
||||
assert torch.all(output_token_ids[subsequent_mask] == -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", list(range(1, 6)))
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
device: str, use_flashinfer: bool):
|
||||
torch.set_default_device(device)
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
|
||||
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("k", [1, 3, 6])
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
@pytest.mark.parametrize("n_rep", [100])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
# @pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# Not testing FlashInfer now, since 0.2.3 API removed the ability
|
||||
# to pass in uniform samples.
|
||||
@pytest.mark.parametrize("use_flashinfer", [False])
|
||||
@torch.inference_mode()
|
||||
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
frac_seeded: float, n_rep: int, device: str,
|
||||
use_flashinfer: bool):
|
||||
torch.set_default_device(device)
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
|
||||
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
|
||||
|
||||
results = []
|
||||
for _ in range(n_rep):
|
||||
seeded_seqs = {
|
||||
i: torch.Generator(device=device).manual_seed(i)
|
||||
for i in range(batch_size) if seeded_mask[i]
|
||||
}
|
||||
results.append(
|
||||
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids, seeded_seqs))
|
||||
|
||||
for i in range(batch_size):
|
||||
if seeded_mask[i]:
|
||||
for j in range(1, n_rep):
|
||||
assert torch.equal(results[j][i], results[0][i])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", [1, 3, 6])
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
# @pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# Not testing FlashInfer now, since 0.2.3 API removed the ability
|
||||
# to pass in uniform samples.
|
||||
@pytest.mark.parametrize("use_flashinfer", [False])
|
||||
@torch.inference_mode()
|
||||
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
|
||||
device: str, use_flashinfer: bool):
|
||||
torch.set_default_device(device)
|
||||
set_random_seed(0)
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
|
||||
single_batches = []
|
||||
for i in range(batch_size):
|
||||
single_batches.append((draft_probs[i].clone().unsqueeze(0),
|
||||
draft_token_ids[i].clone().unsqueeze(0),
|
||||
target_probs[i].clone().unsqueeze(0),
|
||||
bonus_token_ids[i].clone().unsqueeze(0),
|
||||
draft_token_ids[i].clone().unsqueeze(0)))
|
||||
|
||||
set_random_seed(0)
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
results = []
|
||||
seeded_seqs = {
|
||||
i: torch.Generator(device=device).manual_seed(i)
|
||||
for i in range(1, batch_size) # 0 is seed None
|
||||
}
|
||||
batch_result = rejection_sampler(target_probs.clone(),
|
||||
bonus_token_ids.clone(),
|
||||
draft_probs.clone(),
|
||||
draft_token_ids.clone(), seeded_seqs)
|
||||
|
||||
set_random_seed(0)
|
||||
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
for i in range(batch_size):
|
||||
request_seeded_seqs = {
|
||||
0: torch.Generator(device=device).manual_seed(i)
|
||||
} if seeded_seqs.get(i) is not None else None
|
||||
(draft_probs, draft_token_ids, target_probs, bonus_token_ids,
|
||||
draft_token_ids) = single_batches[i]
|
||||
results.append(
|
||||
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids, request_seeded_seqs))
|
||||
for i in range(batch_size):
|
||||
assert torch.equal(batch_result[i], results[i].squeeze(0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", [1, 3, 6])
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
|
||||
batch_size: int, device: str):
|
||||
"""
|
||||
Test the flashinfer and nonflashinfer backend generate
|
||||
the same output metrics.
|
||||
"""
|
||||
|
||||
pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
|
||||
"the ability to pass in uniform samples.")
|
||||
|
||||
torch.set_default_device(device)
|
||||
torch.manual_seed(0)
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
|
||||
num_accepted_tokens = []
|
||||
num_emitted_tokens = []
|
||||
num_draft_tokens = []
|
||||
|
||||
def get_seeded_seqs():
|
||||
return {
|
||||
i: torch.Generator(device=device).manual_seed(i)
|
||||
for i in range(batch_size)
|
||||
}
|
||||
|
||||
for use_flashinfer in [True, False]:
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
# We use seeded sequences to ensure the same tokens are accepted
|
||||
# for both flashinfer and nonflashinfer backends.
|
||||
seeded_seqs = get_seeded_seqs()
|
||||
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids, seeded_seqs)
|
||||
num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
|
||||
num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
|
||||
num_draft_tokens.append(rejection_sampler.num_draft_tokens)
|
||||
|
||||
assert num_accepted_tokens[0] == num_accepted_tokens[1]
|
||||
assert num_emitted_tokens[0] == num_emitted_tokens[1]
|
||||
assert num_draft_tokens[0] == num_draft_tokens[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
|
||||
@pytest.mark.parametrize("which_token_ids",
|
||||
["bonus_token_ids", "draft_token_ids"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
which_token_ids: str, device: str,
|
||||
use_flashinfer: bool):
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
|
||||
strict_mode=True)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
|
||||
oob_token_ids = None
|
||||
if which_token_ids == "bonus_token_ids":
|
||||
oob_token_ids = bonus_token_ids
|
||||
elif which_token_ids == "draft_token_ids":
|
||||
oob_token_ids = draft_token_ids
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
if above_or_below_vocab_range == "above":
|
||||
rogue_token_id = vocab_size + 1
|
||||
elif above_or_below_vocab_range == "below":
|
||||
rogue_token_id = -1
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
oob_token_ids[0][0] = rogue_token_id
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
|
||||
@pytest.mark.parametrize("seed", list(range(5)))
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_rejection_sampling_approximates_target_distribution(
|
||||
seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
|
||||
"""Verify rejection sampling approximates target distribution,
|
||||
despite sampling from a potentially distinct draft distribution.
|
||||
|
||||
This is done by first creating a random target probability
|
||||
distribution and a random draft probability distribution. We then
|
||||
sample token ids from the rejection sampler using these draft
|
||||
and target distributions. The samples are used to estimate
|
||||
the output probability distribution, which we expect to approximate
|
||||
the target distribution.
|
||||
|
||||
A basic distance metric is used to determine similarity between
|
||||
distributions.
|
||||
|
||||
We expect that as we increase the number of samples,
|
||||
the distance between the observed distribution and the target
|
||||
distribution decreases. To measure this, we compare the distance
|
||||
of the observed distribution against both the target distribution
|
||||
and a uniform random distribution. We expect the distance between
|
||||
the observed distribution and the target distribution to improve
|
||||
much more than the distance improvement between the observed
|
||||
distribution and the random distribution.
|
||||
|
||||
When draft_and_target_probs_equal=True, the draft and target
|
||||
probabilities are exactly equal. Rejection sampling should
|
||||
still work without any NaNs or exceptions.
|
||||
"""
|
||||
torch.set_default_device("cpu")
|
||||
set_random_seed(seed)
|
||||
helper = _CorrectnessTestHelper(
|
||||
vocab_size=10,
|
||||
rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
|
||||
)
|
||||
|
||||
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
|
||||
draft_and_target_probs_equal)
|
||||
|
||||
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
|
||||
distance_wrt_reference: list[float] = []
|
||||
distance_wrt_target: list[float] = []
|
||||
|
||||
for num_samples in sample_sizes:
|
||||
(reference_vs_rejsample_dist,
|
||||
target_vs_rejsample_dist) = helper.run_and_compare_distributions(
|
||||
draft_probs,
|
||||
target_probs,
|
||||
reference_probs,
|
||||
num_samples,
|
||||
)
|
||||
|
||||
distance_wrt_reference.append(reference_vs_rejsample_dist)
|
||||
distance_wrt_target.append(target_vs_rejsample_dist)
|
||||
|
||||
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
|
||||
distance_wrt_target)
|
||||
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
|
||||
distance_wrt_reference)
|
||||
|
||||
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
|
||||
f"{reference_vs_rejsample_dist=:.05f}")
|
||||
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
|
||||
f"{relative_change_in_distance_wrt_reference=:.02f}")
|
||||
|
||||
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
|
||||
distance_wrt_target)
|
||||
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
|
||||
distance_wrt_reference)
|
||||
|
||||
expected_improvement_multiplier = 20
|
||||
assert (relative_change_in_distance_wrt_target
|
||||
> relative_change_in_distance_wrt_reference *
|
||||
expected_improvement_multiplier)
|
||||
|
||||
|
||||
def get_ratio_first_to_last(elements: list[float]) -> float:
|
||||
return elements[0] / elements[-1]
|
||||
|
||||
|
||||
class _CorrectnessTestHelper:
|
||||
"""Class that packages together logic required for the unit-level
|
||||
rejection sampling correctness test.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
|
||||
self.rejection_sampler = rejection_sampler
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_range = (0, vocab_size)
|
||||
|
||||
self.rejection_sampler.init_gpu_tensors(device=0)
|
||||
|
||||
# Keep test simple, use k=1
|
||||
self.k = 1
|
||||
|
||||
# Bonus tokens not used, but rejection sampler requires
|
||||
# correct shape.
|
||||
self.num_bonus_tokens = 1
|
||||
|
||||
def generate_probs_for_test(
|
||||
self, draft_and_target_probs_equal: bool
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
draft_probs, target_probs = (F.softmax(
|
||||
torch.rand(self.vocab_size, dtype=torch.float32),
|
||||
dim=-1,
|
||||
) for _ in range(2))
|
||||
|
||||
num_reference_probs = 100
|
||||
reference_probs = F.softmax(
|
||||
torch.rand(num_reference_probs,
|
||||
self.vocab_size,
|
||||
dtype=torch.float32),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if draft_and_target_probs_equal:
|
||||
target_probs = draft_probs.clone()
|
||||
|
||||
return draft_probs, target_probs, reference_probs
|
||||
|
||||
def run_and_compare_distributions(self, draft_probs: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
reference_probs: torch.Tensor,
|
||||
num_samples: int) -> tuple[float, float]:
|
||||
# Sample using rejection sampling.
|
||||
rej_sample_probs = self._estimate_rejection_sampling_pdf(
|
||||
draft_probs, target_probs, num_samples)
|
||||
|
||||
# Average distance from reference probs.
|
||||
reference_vs_rejsample_dist = torch.dist(
|
||||
reference_probs,
|
||||
rej_sample_probs).item() / reference_probs.shape[0]
|
||||
target_vs_rejsample_dist = torch.dist(target_probs,
|
||||
rej_sample_probs).item()
|
||||
|
||||
return reference_vs_rejsample_dist, target_vs_rejsample_dist
|
||||
|
||||
def _estimate_rejection_sampling_pdf(
|
||||
self,
|
||||
draft_probs: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
num_samples: int,
|
||||
) -> torch.Tensor:
|
||||
# Repeat draft probs num_samples times.
|
||||
draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
|
||||
num_samples, 1, 1)
|
||||
|
||||
# Repeat target probs num_samples * (k + 1) times.
|
||||
# Rejection sampler requires bonus token probs, but they aren't used.
|
||||
target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
|
||||
num_samples, self.k + 1, 1)
|
||||
|
||||
# Randomly sample draft token ids from draft probs.
|
||||
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
|
||||
num_samples=1,
|
||||
replacement=True).reshape(
|
||||
num_samples, self.k)
|
||||
|
||||
# Bonus tokens not used but required.
|
||||
bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
|
||||
dtype=torch.int64,
|
||||
device="cuda").repeat(num_samples, 1)
|
||||
|
||||
# Get output tokens via rejection sampling.
|
||||
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
|
||||
bonus_token_ids.to("cuda"),
|
||||
draft_probs.to("cuda"),
|
||||
draft_token_ids.to("cuda"))
|
||||
|
||||
# Remove bonus tokens
|
||||
output_token_ids = output_token_ids[:, :-1].flatten()
|
||||
|
||||
# Estimate probability density function
|
||||
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
|
||||
device="cpu"),
|
||||
bins=self.vocab_size,
|
||||
range=self.vocab_range,
|
||||
density=True)
|
||||
|
||||
return hist.hist
|
||||
@ -1,480 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for rejection sampling."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This file tests V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
|
||||
"""
|
||||
Generates a fake temperature zero probability distribution.
|
||||
Returns:
|
||||
1. A fake temperature zero probability distribution of shape
|
||||
[batch_size, k, vocab_size]
|
||||
2. Tensor of shape [batch_size, k] containing the token ids
|
||||
of the probability 1.0 tokens at each position.
|
||||
"""
|
||||
# Simulate temperature 0 probability distribution for target probabilities
|
||||
# and create target probabilities such that only 1 token id has
|
||||
# probability 1.0
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
probs = torch.rand(batch_size, k, vocab_size)
|
||||
_, zero_temperature_token_ids = torch.max(probs, dim=-1)
|
||||
# set the probability of the tokens with ids in zero_temperature_token_ids
|
||||
# to 1 and the rest to 0.
|
||||
target_probs = torch.zeros_like(probs).scatter_(
|
||||
-1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
|
||||
return target_probs, zero_temperature_token_ids
|
||||
|
||||
|
||||
def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
|
||||
token_ids_to_exclude: torch.Tensor):
|
||||
"""
|
||||
Returns a tensor of shape [batch_size, k] of fake draft token ids
|
||||
drawn randomly from a vocab of size vocab_size. We however ensure
|
||||
that token_ids from token_ids_to_exclude are excluded at the
|
||||
corresponding positions.
|
||||
"""
|
||||
draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
|
||||
for i in range(batch_size):
|
||||
for j in range(k):
|
||||
# Generate a random token ID excluding token_ids_to_exclude[i, j]
|
||||
while True:
|
||||
token_id = torch.randint(0, vocab_size, (1, )).item()
|
||||
if token_id != token_ids_to_exclude[i, j]:
|
||||
draft_token_ids[i, j] = token_id
|
||||
break
|
||||
return draft_token_ids
|
||||
|
||||
|
||||
def get_acceptance_sampler(
|
||||
posterior_threshold: float = 0.03,
|
||||
posterior_alpha: float = 0.9,
|
||||
strict_mode: bool = False,
|
||||
) -> TypicalAcceptanceSampler:
|
||||
"""
|
||||
Initializes and returns a TypicalAcceptanceSampler.
|
||||
"""
|
||||
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
|
||||
strict_mode)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", list(range(1, 6)))
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
device: str):
|
||||
"""
|
||||
Tests that the TypicalAcceptancSampler forward succeeds for
|
||||
different combinations of k, vocab_size, batch_size and num devices.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler()
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_with_bonus_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
# Verify that sampling succeeds for all cases.
|
||||
typical_acceptance_sampler(target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
|
||||
@pytest.mark.parametrize("which_token_ids",
|
||||
["bonus_token_ids", "draft_token_ids"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
which_token_ids: str, device: str):
|
||||
"""
|
||||
Tests that we throw an exception of the token ids fall outside
|
||||
the bound of the provided vocabulary.
|
||||
"""
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_with_bonus_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
# Verify that appropriate exceptions are thrown for out
|
||||
# of bound vocabs.
|
||||
oob_token_ids = None
|
||||
if which_token_ids == "bonus_token_ids":
|
||||
oob_token_ids = bonus_token_ids
|
||||
elif which_token_ids == "draft_token_ids":
|
||||
oob_token_ids = draft_token_ids
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
if above_or_below_vocab_range == "above":
|
||||
rogue_token_id = vocab_size + 1
|
||||
elif above_or_below_vocab_range == "below":
|
||||
rogue_token_id = -1
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
oob_token_ids[0][0] = rogue_token_id
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
typical_acceptance_sampler(target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_uniform_target_distribution_accepts_all_tokens(
|
||||
seed: int, device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with a uniform target probability
|
||||
distribution.
|
||||
|
||||
This test verifies that when provided with a uniform target probability
|
||||
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
|
||||
entropy of the uniform target distribution being high should lead to all
|
||||
draft tokens being accepted.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_with_bonus_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
# We are using a uniform target probability distribution.
|
||||
# For a uniform distribution the entropy is very high and it
|
||||
# should lead to all draft tokens being accepted. Verify that.
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
|
||||
|
||||
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_temperature_zero_target_distribution(seed: int, device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with a zero-temperature target
|
||||
probability distribution.
|
||||
|
||||
This test verifies that when using a zero-temperature target probability
|
||||
distribution, where only one token has a probability of 1.0, the
|
||||
TypicalAcceptanceSampler correctly rejects all draft tokens that do not
|
||||
match this probability. Additionally, it ensures that when all draft
|
||||
tokens are rejected, the sampler falls back to greedy sampling to select a
|
||||
single token from the target distribution.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# Simulate temperature 0 probability distribution for target probabilities
|
||||
# and create target probabilities such that only 1 token id has
|
||||
# probability 1.0
|
||||
target_with_bonus_probs, zero_temperature_token_ids = \
|
||||
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
# Populate draft_token_ids such that they exclude the token_ids
|
||||
# with probability = 1.0
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
# The target probaility distribution is a temperature zero distribution
|
||||
# with zero entropy. Since our draft token ids don't match the probability
|
||||
# 1.0 tokens in the target distribution we will reject all of them and
|
||||
# fallback to the greedy sampling for selecting 1 token for each sequence.
|
||||
# Verify the same.
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, -1] == -1)
|
||||
assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:,
|
||||
0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_mixed_target_distribution(seed: int, device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with a mixed target probability
|
||||
distribution.
|
||||
|
||||
This test ensures that the TypicalAcceptanceSampler handles a mixed
|
||||
target probability distribution correctly. Specifically, it uses a
|
||||
zero-temperature distribution for some sequences and a uniform
|
||||
distribution for others. The test verifies that:
|
||||
|
||||
- For sequences with a zero-temperature distribution, only the token
|
||||
with a probability of 1.0 is accepted, and all other tokens are rejected.
|
||||
- For sequences with a uniform distribution, all draft tokens are
|
||||
accepted.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 3
|
||||
batch_size = 4
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# For sequences 0 and 2 set the distribution to a temperature
|
||||
# zero distribution. For sequences 1 and 3 set it to a uniform
|
||||
# distribution.
|
||||
target_with_bonus_probs, zero_temperature_token_ids = \
|
||||
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
target_probs = target_with_bonus_probs[:, :-1]
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
|
||||
target_probs[[1, 3]] = uniform_probs
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
# verify the shape of output_token_ids
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
# For sequences 0 and 2 verify that only 1 token is accepted
|
||||
# which is the token with probability 1.0 in the target distribution
|
||||
# at position 0.
|
||||
assert torch.all(output_token_ids[[0, 2], 1:] == -1)
|
||||
assert (torch.all(output_token_ids[[0, 2],
|
||||
0] == zero_temperature_token_ids[[0, 2],
|
||||
0]))
|
||||
# For sequences 1 and 3 verify that all tokens are accepted since the
|
||||
# target probability distribution is uniform. In addition verify that
|
||||
# we also accept the bonus tokens.
|
||||
assert torch.all(
|
||||
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
|
||||
assert torch.all(output_token_ids[[1, 3], -1] != -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_accept_tokens_partially(seed: int, device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
|
||||
tokens should be accepted.
|
||||
|
||||
This test verifies that the TypicalAcceptanceSampler correctly accepts or
|
||||
rejects draft tokens based on a zero-temperature target probability
|
||||
distribution. Specifically, it ensures that:
|
||||
|
||||
- When all draft tokens match tokens with a probability of 1.0 in the
|
||||
target distribution, all draft tokens are accepted.
|
||||
- When only some draft tokens match tokens with a probability of 1.0 in
|
||||
the target distribution, only those matching tokens are accepted, and the
|
||||
rest are rejected.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 5
|
||||
batch_size = 1
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# Create a temperature zero target probability distribution and ensure
|
||||
# all draft token ids correspond to the tokens with 1.0 probability.
|
||||
# Verify that all of them are accepted.
|
||||
target_with_bonus_probs, zero_temperature_token_ids = \
|
||||
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
draft_token_ids = zero_temperature_token_ids
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
|
||||
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
|
||||
# Next only keep the first 2 draft tokens same as the zero temperature
|
||||
# tokens. For the remaining 3 choose some other tokens. In the
|
||||
# response we will expect the first 2 tokens to be the same as the
|
||||
# draft tokens and the recovered token and rest as -1
|
||||
draft_token_ids_to_replace = get_draft_token_ids(
|
||||
batch_size, k, vocab_size, zero_temperature_token_ids)
|
||||
draft_token_ids = torch.cat(
|
||||
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
|
||||
assert torch.all(
|
||||
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
|
||||
assert torch.all(output_token_ids[:, -3:] == -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(1)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with custom posterior thresholds and
|
||||
alpha values. This test verifies that by modifying the posterior
|
||||
thresholds and alpha values we can change the acceptance behavior of the
|
||||
sampler.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 5
|
||||
batch_size = 1
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# Simulate temperature 0 probability distribution for target
|
||||
# probabilities and create target probabilities such that only 1 token
|
||||
# id has probability 1.0 and others have a very low probability of
|
||||
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
|
||||
# with probability = 1.0. Without any changes to the posterior thresholds
|
||||
# none of the draft tokens are accepted.
|
||||
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
|
||||
batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
target_probs[target_probs == 0] = 0.00001
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 1:-1] == -1)
|
||||
|
||||
# Change the posterior threshold values to 0.0 so that we will
|
||||
# now accept even draft tokens with very low probability in the
|
||||
# target distribution. Simulate and verify the same.
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
|
||||
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_get_recovered_token_ids(seed: int, device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler's method for generating
|
||||
replacement token IDs.
|
||||
|
||||
This test verifies that the `_get_recovered_token_ids` method of the
|
||||
TypicalAcceptanceSampler correctly identifies the token IDs to be used
|
||||
as recovered token IDs based on the target probability distribution.
|
||||
Specifically, it ensures that the method correctly identifies the
|
||||
tokens with the highest probability for each sequence in the batch.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 10
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
|
||||
actual_replacement_tokens = (
|
||||
typical_acceptance_sampler._get_recovered_token_ids(target_probs))
|
||||
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
|
||||
@ -1,12 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
@ -1,307 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from itertools import cycle
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import PromptLogprobs, SampleLogprobs
|
||||
|
||||
from ...models.utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
check_logprobs_close, check_outputs_equal)
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"San Francisco is know for its",
|
||||
"Facebook was created in 2004 by",
|
||||
"Curious George is a",
|
||||
"Python 3.11 brings improvements to its",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
test_llm_kwargs, seed):
|
||||
|
||||
def generate():
|
||||
kwargs = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
|
||||
llm = LLM(**kwargs)
|
||||
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
|
||||
yield llm
|
||||
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
return generate
|
||||
|
||||
|
||||
def maybe_assert_ngram_worker(llm):
|
||||
# Verify the proposer worker is ngram if ngram is specified.
|
||||
if (llm.llm_engine.speculative_config is not None
|
||||
and llm.llm_engine.speculative_config.method == "ngram"):
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
assert isinstance(
|
||||
llm.llm_engine.model_executor.driver_worker.proposer_worker,
|
||||
NGramWorker)
|
||||
|
||||
|
||||
def get_output_from_llm_generator(
|
||||
llm_generator, prompts,
|
||||
sampling_params) -> tuple[list[str], list[list[int]], float]:
|
||||
tokens: list[str] = []
|
||||
token_ids: list[list[int]] = []
|
||||
acceptance_rate: float = -1.0
|
||||
for llm in llm_generator():
|
||||
maybe_assert_ngram_worker(llm)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
|
||||
token_ids = [output.outputs[0].token_ids for output in outputs]
|
||||
tokens = [output.outputs[0].text for output in outputs]
|
||||
|
||||
# Fetch acceptance rate if logging is enabled.
|
||||
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
|
||||
stat_logger = stat_loggers["prometheus"]
|
||||
acceptance_rate = (stat_logger.metrics.
|
||||
gauge_spec_decode_draft_acceptance_rate.labels(
|
||||
**stat_logger.labels)._value.get())
|
||||
del llm
|
||||
|
||||
return tokens, token_ids, acceptance_rate
|
||||
|
||||
|
||||
def check_logprobs_correctness(
|
||||
spec_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
baseline_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
disable_logprobs: bool = False,
|
||||
):
|
||||
"""Compare sampled and prompt logprobs between baseline and spec decoding
|
||||
"""
|
||||
if not disable_logprobs:
|
||||
return check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=spec_outputs,
|
||||
name_0="org",
|
||||
name_1="sd",
|
||||
)
|
||||
|
||||
# Check correctness when disable_logprobs == True
|
||||
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
|
||||
# Check generated token logprobs.
|
||||
spec_logprobs = spec_output[2]
|
||||
baseline_logprobs = baseline_output[2]
|
||||
_check_logprobs_when_output_disabled(spec_logprobs,
|
||||
baseline_logprobs,
|
||||
is_prompt_logprobs=False)
|
||||
|
||||
# Check prompt logprobs too, if they exist
|
||||
if len(baseline_output) == 4:
|
||||
assert len(spec_output) == 4
|
||||
spec_prompt_logprobs = spec_output[3]
|
||||
baseline_prompt_logprobs = baseline_output[3]
|
||||
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
|
||||
baseline_prompt_logprobs,
|
||||
is_prompt_logprobs=True)
|
||||
|
||||
|
||||
def _check_logprobs_when_output_disabled(
|
||||
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
is_prompt_logprobs: bool = False,
|
||||
):
|
||||
# Prompt logprobs are optional
|
||||
if is_prompt_logprobs and baseline_logprobs is None:
|
||||
assert spec_logprobs is None
|
||||
return
|
||||
|
||||
assert spec_logprobs is not None
|
||||
assert baseline_logprobs is not None
|
||||
assert len(spec_logprobs) == len(baseline_logprobs)
|
||||
|
||||
# For each generated position of the sequence.
|
||||
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
|
||||
zip(spec_logprobs, baseline_logprobs)):
|
||||
|
||||
# First prompt logprob is expected to be None
|
||||
if is_prompt_logprobs and baseline_pos_logprobs is None:
|
||||
assert spec_pos_logprobs is None
|
||||
assert pos == 0
|
||||
continue
|
||||
|
||||
assert spec_pos_logprobs is not None
|
||||
assert baseline_pos_logprobs is not None
|
||||
|
||||
# When disabled, the 1 logprob is returned with dummy values for the
|
||||
# score and rank, but the token id should match the baseline model
|
||||
assert len(spec_pos_logprobs) == 1
|
||||
(spec_pos_logprob_token_id,
|
||||
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
|
||||
assert spec_pos_logprob.rank == -1
|
||||
assert spec_pos_logprob.logprob == 0.0
|
||||
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
|
||||
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
|
||||
assert spec_pos_logprob_token_id in baseline_pos_logprobs
|
||||
|
||||
|
||||
def run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size: int,
|
||||
max_output_len: int,
|
||||
seed: Optional[int] = 0,
|
||||
temperature: float = 0.0,
|
||||
disable_seed: bool = False,
|
||||
ignore_eos: bool = True,
|
||||
ensure_all_accepted: bool = False,
|
||||
expected_acceptance_rate: Optional[float] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
disable_logprobs: bool = False):
|
||||
|
||||
org_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**baseline_llm_kwargs,
|
||||
}
|
||||
|
||||
sd_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||
|
||||
if disable_seed:
|
||||
seed = None
|
||||
|
||||
sampling_params = SamplingParams(temperature=temperature,
|
||||
max_tokens=max_output_len,
|
||||
seed=seed,
|
||||
ignore_eos=ignore_eos,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
|
||||
with vllm_runner(**org_args) as vllm_model:
|
||||
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
with vllm_runner(**sd_args) as vllm_model:
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
# Force log interval to be 0 to catch all metrics.
|
||||
stat_logger = vllm_model.model.llm_engine.stat_loggers[
|
||||
'prometheus']
|
||||
stat_logger.local_interval = -100
|
||||
|
||||
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
acceptance_rate = (stat_logger.metrics.
|
||||
gauge_spec_decode_draft_acceptance_rate.labels(
|
||||
**stat_logger.labels)._value.get())
|
||||
|
||||
if ensure_all_accepted:
|
||||
assert True
|
||||
# FIXME: ci fails to log acceptance rate.
|
||||
# It works locally.
|
||||
# assert acceptance_rate == 1.0
|
||||
|
||||
if expected_acceptance_rate is not None:
|
||||
assert acceptance_rate >= expected_acceptance_rate - 1e-2
|
||||
|
||||
# Only pass token entries, not the logprobs
|
||||
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
|
||||
outputs_1_lst=[out[0:2] for out in sd_outputs],
|
||||
name_0="org",
|
||||
name_1="sd")
|
||||
|
||||
# Check logprobs if requested
|
||||
if logprobs is not None or prompt_logprobs is not None:
|
||||
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
|
||||
|
||||
|
||||
def run_equality_correctness_test_tp(model,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size: int,
|
||||
max_output_len: int,
|
||||
seed: int = 0,
|
||||
temperature: float = 0.0,
|
||||
logprobs: Optional[int] = None):
|
||||
"""Helper method that compares the outputs of both the baseline LLM and
|
||||
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
||||
the same when temperature is zero.
|
||||
"""
|
||||
arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
|
||||
arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
|
||||
env1 = env2 = None
|
||||
|
||||
max_wait_seconds = 240
|
||||
results = []
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||
for args, env in ((arg1, env1), (arg2, env2)):
|
||||
with RemoteOpenAIServer(model,
|
||||
args,
|
||||
env_dict=env,
|
||||
max_wait_seconds=max_wait_seconds) as server:
|
||||
client = server.get_client()
|
||||
|
||||
completion = client.completions.create(model=model,
|
||||
prompt=prompts,
|
||||
max_tokens=max_output_len,
|
||||
seed=seed,
|
||||
temperature=temperature,
|
||||
logprobs=logprobs)
|
||||
|
||||
results.append({
|
||||
"test":
|
||||
"seeded_sampling",
|
||||
"text": [choice.text for choice in completion.choices],
|
||||
"logprobs": [choice.logprobs for choice in completion.choices],
|
||||
"finish_reason":
|
||||
[choice.finish_reason for choice in completion.choices],
|
||||
"usage":
|
||||
completion.usage,
|
||||
})
|
||||
|
||||
n = len(results) // 2
|
||||
arg1_results = results[:n]
|
||||
arg2_results = results[n:]
|
||||
# Separate logprobs to avoid asserting exact equality.
|
||||
arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
|
||||
arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
|
||||
|
||||
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
|
||||
assert arg1_result == arg2_result, (
|
||||
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
|
||||
f"{arg1_result=} != {arg2_result=}")
|
||||
if logprobs:
|
||||
for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
|
||||
for l1, l2 in zip(logs1, logs2):
|
||||
assert l1.tokens == l2.tokens
|
||||
@ -1,66 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from .conftest import get_output_from_llm_generator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("common_llm_kwargs",
|
||||
[{
|
||||
"model": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
# Speculative max model len > overridden max model len should raise.
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 129,
|
||||
},
|
||||
"max_model_len": 128,
|
||||
},
|
||||
{
|
||||
# Speculative max model len > draft max model len should raise.
|
||||
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 2048 + 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
# Speculative max model len > target max model len should raise.
|
||||
# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 131072 + 1,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
|
||||
"""Verify that speculative decoding validates speculative_max_model_len.
|
||||
"""
|
||||
output_len = 128
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="cannot be larger than"):
|
||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||
sampling_params)
|
||||
@ -1,480 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, EAGLE would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# num_heads in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 4
|
||||
|
||||
# precision
|
||||
PRECISION = "float32"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness_cuda_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 8,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that eagle speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that eagle speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": "float16",
|
||||
|
||||
# Main model
|
||||
"model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "yuhuili/EAGLE-llama2-chat-7B",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# 2 for small prompt, 256//16 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 16,
|
||||
"max_model_len": (2 + 256 // 16) * 16,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": "float16",
|
||||
|
||||
# Main model
|
||||
"model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# 2 for small prompt, 256//16 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 16,
|
||||
"max_model_len": (2 + 256 // 16) * 16,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": "float16",
|
||||
|
||||
# Main model
|
||||
"model_name": "Qwen/Qwen2-7B-Instruct",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
@ -1,161 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests which cover integration of the speculative decoding framework with
|
||||
other features, e.g. cuda graphs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Verify equality when cuda graphs allowed.
|
||||
"enforce_eager": False,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
# Identical models.
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("output_len", [32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int):
|
||||
"""Verify spec decode equality when cuda graphs are enabled.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
# Explicitly specify draft model quantization
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": "gptq",
|
||||
},
|
||||
},
|
||||
# Explicitly specify GPTQ-based draft model to use marlin quantization
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": "marlin",
|
||||
},
|
||||
},
|
||||
# Not explicitly specify draft model quantization
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": None,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size: int, seed: int):
|
||||
"""Verify spec decode works well with draft model quantization configs.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=32,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@ -1,247 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests which cover integration of the speculative decoding framework with
|
||||
tensor parallelism.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .conftest import run_equality_correctness_test_tp
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
"--tensor-parallel-size",
|
||||
"2"
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
[
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}),
|
||||
],
|
||||
[
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int):
|
||||
"""Verify greedy equality when tensor parallelism is used.
|
||||
"""
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("hip is not well-supported yet")
|
||||
run_equality_correctness_test_tp("JackFram/llama-68m",
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
"--tensor_parallel_size",
|
||||
"2",
|
||||
|
||||
# precision
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize(
|
||||
"model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
]),
|
||||
("ibm-granite/granite-3b-code-instruct", [
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "ibm-granite/granite-3b-code-instruct",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
])])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
seed: int):
|
||||
"""Verify spec decode works well with smaller tp for draft models.
|
||||
"""
|
||||
run_equality_correctness_test_tp(model,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=32,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
"--tensor_parallel_size",
|
||||
"2",
|
||||
|
||||
# precision
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
]])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[["--enable-chunked-prefill", "False"],
|
||||
[
|
||||
"--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
|
||||
"--max-num-seqs", "4"
|
||||
]])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}),
|
||||
]),
|
||||
("JackFram/llama-68m", [
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
])])
|
||||
@pytest.mark.parametrize("logprobs", [None])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
logprobs: Optional[int],
|
||||
batch_size: int, seed: int):
|
||||
"""Verify spec decode works well with same and different TP size for
|
||||
the draft model with chunked prefill.
|
||||
"""
|
||||
run_equality_correctness_test_tp(model,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=32,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
"--tensor_parallel_size",
|
||||
"2",
|
||||
|
||||
# precision
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
]])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[["--enable-chunked-prefill", "False"],
|
||||
[
|
||||
"--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
|
||||
"--max-num-seqs", "4"
|
||||
]])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
}),
|
||||
]),
|
||||
("JackFram/llama-68m", [
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
"disable_logprobs": False,
|
||||
}),
|
||||
])])
|
||||
@pytest.mark.parametrize("logprobs", [2])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_chunked_prefill_tp2_with_logprobs(
|
||||
model, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int],
|
||||
batch_size: int, seed: int):
|
||||
"""Verify spec decode works well with same and different TP size for
|
||||
the draft model with chunked prefill.
|
||||
"""
|
||||
run_equality_correctness_test_tp(model,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=32,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs)
|
||||
@ -1,123 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests which cover integration of the speculative decoding framework with
|
||||
tensor parallelism.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from .conftest import run_equality_correctness_test_tp
|
||||
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
SPEC_MODEL = "JackFram/llama-68m"
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce_eager",
|
||||
"--tensor-parallel-size",
|
||||
"4",
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
[],
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
#TODO(wooyeon): add spec_draft_dp=2 case
|
||||
[
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
seed: int):
|
||||
"""Verify spec decode works well with smaller tp for draft models.
|
||||
"""
|
||||
run_equality_correctness_test_tp(MAIN_MODEL,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=32,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[[
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
"--tensor-parallel-size",
|
||||
"4",
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
[
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"--speculative_config",
|
||||
json.dumps({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# This must be a good bit larger than speculative_max_model_len so that
|
||||
# we can test the case where all seqs are skipped, but still small to
|
||||
# ensure fast test.
|
||||
64,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_skip_speculation(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int):
|
||||
"""Verify job failure with RuntimeError when all sequences skip speculation.
|
||||
We do this by setting the max model len of the draft model to an
|
||||
artificially low value, such that when the sequences grow beyond it, they
|
||||
are skipped in speculative decoding.
|
||||
|
||||
TODO: fix it to pass without raising Error. (#5814)
|
||||
"""
|
||||
with pytest.raises(
|
||||
(openai.APIConnectionError, openai.InternalServerError)):
|
||||
run_equality_correctness_test_tp(MAIN_MODEL,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0)
|
||||
@ -1,315 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from itertools import cycle
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
7,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
|
||||
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int, prefill_chunk_size: int):
|
||||
"""Verify output logprobs are equal with and without speculative decoding,
|
||||
as well as with and without chunked prefill.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 6,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int, logprobs: int):
|
||||
"""Veriy logprob greedy equality with different speculation lens.
|
||||
"""
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
# Artificially limit the draft model max model len; this forces
|
||||
# vLLM to skip speculation once the sequences grow beyond 32-k
|
||||
# tokens.
|
||||
"max_model_len": 32,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1])
|
||||
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify logprobs greedy equality when some sequences skip speculation.
|
||||
"""
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [6])
|
||||
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify at least one logprob result has num_logprobs+1, which tests the
|
||||
case where the sampled token is not in top-k logprobs.
|
||||
|
||||
Ideally, this test should validate equality with non-spec by getting
|
||||
logprobs. This is left as future improvement.
|
||||
"""
|
||||
temperature = 1.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"San Francisco is know for its",
|
||||
"Facebook was created in 2004 by",
|
||||
"Curious George is a",
|
||||
"Python 3.11 brings improvements to its",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
sd_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
|
||||
with vllm_runner(**sd_args) as vllm_model:
|
||||
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
num_returned_logprobs = [
|
||||
len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
|
||||
]
|
||||
|
||||
# Assert one of the returned logprobs has > num_logprobs (indicating the
|
||||
# sampled token is not in top-k).
|
||||
assert any(
|
||||
[num_returned > logprobs for num_returned in num_returned_logprobs])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("logprobs", [0])
|
||||
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
"""Check the behavior when logprobs are disabled.
|
||||
Token choices should match with the base model.
|
||||
"""
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
@ -1,417 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, Medusa would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
|
||||
# OOM in CI pipeline, so using a smaller model.
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
|
||||
|
||||
# max number of speculative tokens: this corresponds to
|
||||
# num_heads in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 5
|
||||
|
||||
# precision
|
||||
PRECISION = "float32"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
8,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int, prefill_chunk_size: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
@ -1,533 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, MLPSpeculator would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "JackFram/llama-160m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "ibm-ai-platform/llama-160m-accelerator"
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# n_predict in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 3
|
||||
|
||||
# precision
|
||||
PRECISION = "float32"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [8])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
# NOTE Test is sensitive enough st if we don't enable chunked prefill
|
||||
# scheduling on baseline too, we get slightly different logprobs, ending
|
||||
# up sampling different tokens at the tail (ie top tokens don't change).
|
||||
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [2048])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify acceptance rate with different batch size and large output
|
||||
length."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=0.0,
|
||||
seed=seed,
|
||||
expected_acceptance_rate=0.48)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Speculative config
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
|
||||
@pytest.mark.parametrize("output_len", [64])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("temperature", [1.0])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
temperature: float,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify seeded runs produce the same output."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
seed=seed)
|
||||
|
||||
# Ensure this same test does fail if we _don't_ include per-request seeds
|
||||
with pytest.raises(AssertionError):
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
seed=seed,
|
||||
disable_seed=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality when the vocab dimension is padded
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
|
||||
# Default pad_to is 64, test model has vocab_size of 32000
|
||||
def patched_pad_vocab_size(vocab_size, pad_to=None):
|
||||
return pad_vocab_size(vocab_size, pad_to=32064)
|
||||
|
||||
with patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
|
||||
patched_pad_vocab_size):
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int, output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
# Speculative decoding is disabled when sequences reach decoding and the batch
|
||||
# consists of single-token requests. Hence we set `max_num_seqs`
|
||||
# >= `speculative_disable_by_batch_size` to test feature interaction.
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int,
|
||||
output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, prefill_chunk_size: int, seed: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@ -1,333 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, mtp would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "luccafong/deepseek_mtp_main_random"
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# num_nextn_predict_layers in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 1
|
||||
|
||||
# precision
|
||||
PRECISION = "bfloat16"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.85
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.85
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
"gpu_memory_utilization": 0.85
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 8,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.9
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.9
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that mtp speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.9
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that mtp speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
@ -1,842 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""The tests in this file verify end-to-end speculative decoding correctness.
|
||||
|
||||
This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality. This gives us good coverage of temp=0.
|
||||
|
||||
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
|
||||
highest probability in the target distribution are accepted. Therefore, we can
|
||||
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
|
||||
|
||||
For temp>0, we rely on unit tests on the rejection sampler to verify that the
|
||||
output distribution is the same with spec decode vs. no spec decode (this would
|
||||
be prohibitively expensive to run with a real model). Similarly, for the
|
||||
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
|
||||
test cases.
|
||||
|
||||
NOTE: Speculative decoding's distribution equality requires that the measured
|
||||
distributions of the target model and proposal model be deterministic given the
|
||||
same input. vLLM largely guarantees this.
|
||||
|
||||
@cadedaniel has seen cases where the output probabilities of a draft/target
|
||||
model change slightly with certain batch sizes or prompts, even with Torch
|
||||
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
|
||||
determinism in on-device batched operations, a bug in vLLM's spec decode
|
||||
implementation, or the "hardware numerics" limitations. Either way, rejection
|
||||
sampling ensures the output distribution matches the target model, but it breaks
|
||||
greedy-equality tests for those batch sizes/prompts.
|
||||
"""
|
||||
|
||||
from itertools import cycle
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
from .conftest import (get_output_from_llm_generator,
|
||||
run_equality_correctness_test)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
||||
"model": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
# Chunked prefill enabled with small value
|
||||
# to make sure we get mixed batches.
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
{
|
||||
# Verify the detokenizer assertions in the test work when spec
|
||||
# decode is disabled.
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||
batch_size: int):
|
||||
"""Run generation with speculative decoding on a batch. Verify the engine
|
||||
generates the correct number of tokens (via ignore_eos=True), and that the
|
||||
detokenization matches HF transformers.
|
||||
"""
|
||||
output_len = 32
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
|
||||
test_llm_generator, prompts, sampling_params)
|
||||
|
||||
# Expect a generation for each prompt in the batch.
|
||||
assert len(batch_token_ids) == len(prompts)
|
||||
|
||||
# Expect each generation to have expected number of tokens (note ignore_eos
|
||||
# is True).
|
||||
assert [len(token_ids)
|
||||
for token_ids in batch_token_ids] == ([output_len] * batch_size)
|
||||
|
||||
# Expect detokenized string to match.
|
||||
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
|
||||
for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
|
||||
expected_tokens = tok.decode(actual_token_ids)
|
||||
print(f"{actual_token_ids=}")
|
||||
assert actual_tokens.strip() == expected_tokens.strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# Try two different tiny base models.
|
||||
# Note that one is equal to the draft model, another isn't.
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use long output len for the small model test.
|
||||
10,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality on a tiny model with batch size of one.
|
||||
|
||||
Since this test is cheaper than other e2e correctness tests, we generate
|
||||
with a higher output_len.
|
||||
|
||||
When the draft model is the same as the target model, we further check
|
||||
whether all speculative tokens are accepted.
|
||||
"""
|
||||
ensure_all_accepted = per_test_common_llm_kwargs.get(
|
||||
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
prompt_logprobs=2,
|
||||
logprobs=2,
|
||||
disable_logprobs=False,
|
||||
temperature=0.0,
|
||||
ensure_all_accepted=ensure_all_accepted)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# Try two different tiny base models.
|
||||
# Note that one is equal to the draft model, another isn't.
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [64])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality on a tiny model and large batch size.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# Try two different tiny base models.
|
||||
# Note that one is equal to the draft model, another isn't.
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("max_output_len", [
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
max_output_len: int, seed: int):
|
||||
"""Verify greedy equality on a tiny model, with a large batch size, and when
|
||||
sampling respects the EOS token.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
ignore_eos=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# A "real" model (not tiny).
|
||||
"model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use decently long output len for a high quality test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality on a "real" model and batch size of 1. This is
|
||||
separate from large BS tests to make identifying the source of bugs easier.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# A "real" model (not tiny).
|
||||
"model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [32])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
64,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality with a "real" model on a nontrivial batch size.
|
||||
This is the closest test to a real production workload.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# https://github.com/triton-lang/triton/issues/2266 tl.dot
|
||||
# doesn't support embedding < 16
|
||||
{
|
||||
"block_size": 16,
|
||||
},
|
||||
{
|
||||
"block_size": 32,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality over different block sizes.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# This must be a good bit larger than speculative_max_model_len so that
|
||||
# we can test the case where all seqs are skipped, but still small to
|
||||
# ensure fast test.
|
||||
64,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality when some (or all) sequences skip speculation.
|
||||
We do this by setting the max model len of the draft model to an
|
||||
artificially low value, such that when the sequences grow beyond it, they
|
||||
are skipped in speculative decoding.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_by_batch_size": 2,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_by_batch_size": 2,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("output_len", [10])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_disable_speculation(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality when all sequences disable speculation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
||||
] + [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
"""Verify that speculative decoding produces exact equality to without spec
|
||||
decode with many different values of k.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"acceptance_method": "typical_acceptance_sampler",
|
||||
},
|
||||
"enable_chunked_prefill": False
|
||||
}
|
||||
# Try a range of common k.
|
||||
for k in [1, 2, 3]
|
||||
] + [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"acceptance_method": "typical_acceptance_sampler",
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
} for k in [1, 2, 3]])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@create_new_process_for_each_test()
|
||||
def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that speculative decoding produces exact equality to without spec
|
||||
decode with TypicalAcceptanceSampler as the draft token acceptance
|
||||
sampling method.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@ -1,392 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
|
||||
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
|
||||
Since there is no model is needed for generate the proposal, we could make
|
||||
the testcase much simpler than drafter multi-step one.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various ngram sizes / speculative sizes
|
||||
|
||||
With those tests, we can say at least, ngram spec would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
8,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=0,
|
||||
seed=seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
] + [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 1,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@ -1,70 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "JackFram/llama-160m"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# speculative config
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32])
|
||||
@pytest.mark.parametrize("temperature", [0.1, 1.0])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
20,
|
||||
])
|
||||
def test_seeded_consistency(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
temperature: float, output_len: int):
|
||||
"""Verify outputs are consistent across multiple runs with same seed
|
||||
"""
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
disable_seed=False,
|
||||
)
|
||||
|
||||
# Ensure this same test does fail if we _don't_ include per-request seeds
|
||||
with pytest.raises(AssertionError):
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
disable_seed=True,
|
||||
)
|
||||
@ -1,110 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
|
||||
from .utils import create_seq_group_metadata_from_prompts, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_target_seq_ids', [100])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
|
||||
"""Verify all new sequence ids are greater than all input
|
||||
seq ids.
|
||||
"""
|
||||
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
|
||||
|
||||
all_seq_ids = [
|
||||
[1, 3, 5, 7],
|
||||
list(range(100)) + [0],
|
||||
[100],
|
||||
]
|
||||
|
||||
for seq_ids in all_seq_ids:
|
||||
max_seq_id = max(seq_ids)
|
||||
iterator = scorer._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access
|
||||
for _ in range(num_target_seq_ids):
|
||||
assert next(iterator) > max_seq_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_get_token_ids_to_score(k: int):
|
||||
"""Verify correct tokens are selected for scoring.
|
||||
"""
|
||||
proposal_token_ids = torch.tensor(
|
||||
list(range(k)),
|
||||
dtype=torch.int64,
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
expected_output: list[list[int]] = [
|
||||
[],
|
||||
]
|
||||
for i in range(proposal_token_ids.shape[0]):
|
||||
expected_output.append(proposal_token_ids[:i + 1].tolist())
|
||||
|
||||
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
|
||||
actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
|
||||
|
||||
actual_output = [
|
||||
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
|
||||
]
|
||||
|
||||
assert actual_output == expected_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_create_single_target_seq_group_metadata(k: int):
|
||||
"""Verify correct creation of a batch-expanded seq group metadata.
|
||||
"""
|
||||
|
||||
prompt_tokens = [1, 2, 3]
|
||||
prev_output_tokens = [4, 5, 6]
|
||||
|
||||
token_ids = list(range(k))
|
||||
|
||||
num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1
|
||||
|
||||
final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
|
||||
token_ids)
|
||||
|
||||
block_size = 32
|
||||
input_seq_group_metadata = create_seq_group_metadata_from_prompts(
|
||||
[prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
|
||||
[prev_output_tokens], [num_tokens_processed])[0]
|
||||
|
||||
input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
|
||||
target_seq_id = 100
|
||||
|
||||
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
|
||||
output = scorer._create_single_target_seq_group_metadata( # pylint: disable=protected-access
|
||||
input_seq_group_metadata,
|
||||
input_seq_id,
|
||||
target_seq_id,
|
||||
token_ids,
|
||||
input_seq_group_metadata.sampling_params,
|
||||
)
|
||||
|
||||
assert output.request_id == input_seq_group_metadata.request_id
|
||||
assert output.sampling_params.repetition_penalty == \
|
||||
input_seq_group_metadata.sampling_params.repetition_penalty
|
||||
assert output.sampling_params.temperature == \
|
||||
input_seq_group_metadata.sampling_params.temperature
|
||||
assert output.sampling_params.top_p == \
|
||||
input_seq_group_metadata.sampling_params.top_p
|
||||
assert output.sampling_params.top_k == \
|
||||
input_seq_group_metadata.sampling_params.top_k
|
||||
assert len(output.seq_data) == 1
|
||||
assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
|
||||
prompt_tokens)
|
||||
assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
|
||||
prev_output_tokens + token_ids)
|
||||
|
||||
assert len(output.block_tables) == 1
|
||||
assert output.block_tables[
|
||||
target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
|
||||
@ -1,90 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from .test_utils import mock_spec_decode_sampler
|
||||
from .utils import create_batch, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('queue_size', [4])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('k', [1])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that speculative tokens are disabled when the batch size
|
||||
exceeds the threshold.
|
||||
"""
|
||||
disable_by_batch_size = 3
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
disable_by_batch_size=disable_by_batch_size)
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
running_queue_size=queue_size)
|
||||
|
||||
if queue_size > disable_by_batch_size:
|
||||
with patch.object(worker,
|
||||
'_run_no_spec',
|
||||
side_effect=ValueError(exception_secret)), \
|
||||
pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
# When the batch size is larger than the threshold,
|
||||
# we expect no speculative tokens (0).
|
||||
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
|
||||
assert seq_group_metadata_list[
|
||||
0].num_speculative_tokens == expected_num_spec_tokens
|
||||
|
||||
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device='cpu', # not used
|
||||
vocab_size=100, # not used
|
||||
# Must be long enough to avoid being skipped due to length.
|
||||
max_proposal_len=1024,
|
||||
)
|
||||
|
||||
if queue_size < disable_by_batch_size:
|
||||
# Should raise exception when executing the mocked draft model.
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
else:
|
||||
# Should not execute the draft model because spec decode is disabled
|
||||
# for all requests. Accordingly, the proposal length should be 0.
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
assert proposals.proposal_lens.tolist() == [0] * batch_size
|
||||
@ -1,91 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
This test verifies that memory usage remains constant (or never grows) when
|
||||
we enable / disable speculation via --speculative-disable-by-batch-size.
|
||||
|
||||
There are a lot of things we try to keep track of between batches of requests
|
||||
and if certain tensors are not freed from memory, can result in CUDA ooms.
|
||||
|
||||
This is particularly relevant for production situations where speculation might
|
||||
be enabled during off hours, but disabled once traffic peaks during the workday.
|
||||
Since traffic will stay high for a long period of time, verifying we do not
|
||||
increase our memory usage over time is essential to prevent possible CUDA ooms.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
from tests.core.utils import create_dummy_prompt
|
||||
from vllm.sequence import SequenceGroup
|
||||
|
||||
ITERATIONS = 100
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
|
||||
|
||||
BATCH_SIZE = 5
|
||||
SPEC_DISABLE_BATCH_SIZE = 2
|
||||
|
||||
|
||||
def add_seq_group_to_engine(engine: vllm.LLMEngine, seq_group: SequenceGroup):
|
||||
scheduler = engine.scheduler[0]
|
||||
scheduler.add_seq_group(seq_group)
|
||||
|
||||
|
||||
"""
|
||||
Since we are using a batch size greater than the disabled batch size,
|
||||
we can ensure we go through the _no_spec codepath for most of our engine steps.
|
||||
"""
|
||||
|
||||
|
||||
def test_memory_usage_no_spec():
|
||||
previous_memory_allocated = None
|
||||
llm = vllm.LLM(model=MAIN_MODEL,
|
||||
speculative_config={
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_by_batch_size": SPEC_DISABLE_BATCH_SIZE,
|
||||
})
|
||||
|
||||
batch_sequences = set()
|
||||
engine = llm.llm_engine
|
||||
|
||||
for i in range(ITERATIONS):
|
||||
seq, seq_group = create_dummy_prompt(request_id=str(i),
|
||||
prompt_length=10,
|
||||
min_tokens=10,
|
||||
max_tokens=10)
|
||||
|
||||
add_seq_group_to_engine(engine, seq_group)
|
||||
|
||||
batch_sequences.add(seq)
|
||||
engine.step()
|
||||
for seq in list(batch_sequences):
|
||||
if seq.is_finished():
|
||||
batch_sequences.remove(seq)
|
||||
|
||||
# If we aren't at our batch size yet, continue
|
||||
if len(batch_sequences) <= BATCH_SIZE:
|
||||
continue
|
||||
|
||||
# Otherwise, loop until at least one request is done
|
||||
while not any(seq.is_finished() for seq in batch_sequences):
|
||||
engine.step()
|
||||
|
||||
# Remove it from the set
|
||||
for seq in list(batch_sequences):
|
||||
if seq.is_finished():
|
||||
batch_sequences.remove(seq)
|
||||
|
||||
# At this point, we are always at the case where we have finished
|
||||
# processing some number of requests from the batch after running
|
||||
# several _no_spec executions. The memory should not have
|
||||
# increased between the previous time this was recorded and the
|
||||
# current time.
|
||||
if previous_memory_allocated is None:
|
||||
previous_memory_allocated = torch.cuda.memory_allocated()
|
||||
else:
|
||||
assert previous_memory_allocated == torch.cuda.memory_allocated()
|
||||
@ -1,205 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
|
||||
|
||||
def test_initial_call_returns_none():
|
||||
"""Expect first call to get metrics to return None.
|
||||
"""
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert maybe_metrics is None
|
||||
|
||||
|
||||
def test_second_call_returns_metrics():
|
||||
"""Expect second call to not return None.
|
||||
"""
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
timer.side_effect = [
|
||||
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rank", [1, 2, 3, 4])
|
||||
def test_nonzero_rank_noop(rank):
|
||||
"""Verify nonzero ranks don't collect metrics.
|
||||
"""
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler)
|
||||
collector.init_gpu_tensors(rank=rank)
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is None
|
||||
|
||||
|
||||
def test_noop_until_time():
|
||||
"""Verify metrics aren't collected until enough time passes.
|
||||
"""
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
timer.side_effect = [
|
||||
0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
|
||||
collect_interval_s + 0.1, collect_interval_s + 0.1
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is None
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is not None
|
||||
|
||||
|
||||
def test_timer_is_reset():
|
||||
"""Verify that the internal timer inside AsyncMetricsCollector
|
||||
is reset after collection.
|
||||
"""
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
timer.side_effect = [
|
||||
0.0,
|
||||
collect_interval_s + 0.1,
|
||||
collect_interval_s + 0.1,
|
||||
collect_interval_s + 0.2,
|
||||
collect_interval_s + 0.2,
|
||||
2 * collect_interval_s + 0.1,
|
||||
2 * collect_interval_s + 0.1,
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is not None
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is None
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_data", [True, False])
|
||||
def test_initial_metrics_has_correct_values(has_data: bool):
|
||||
"""Test correctness of metrics data.
|
||||
"""
|
||||
if has_data:
|
||||
num_accepted_tokens = 103
|
||||
num_emitted_tokens = 104
|
||||
num_draft_tokens = 105
|
||||
else:
|
||||
num_accepted_tokens = 0
|
||||
num_emitted_tokens = 0
|
||||
num_draft_tokens = 0
|
||||
k = 5
|
||||
|
||||
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
|
||||
num_draft_tokens, k)
|
||||
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = num_draft_tokens
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
timer.side_effect = [
|
||||
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
_ = collector.maybe_collect_rejsample_metrics(k)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k)
|
||||
|
||||
assert metrics.num_spec_tokens == k
|
||||
assert metrics.accepted_tokens == num_accepted_tokens
|
||||
assert metrics.draft_tokens == num_draft_tokens
|
||||
assert metrics.emitted_tokens == num_emitted_tokens
|
||||
|
||||
if has_data:
|
||||
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
|
||||
num_draft_tokens)
|
||||
assert (metrics.system_efficiency == num_emitted_tokens /
|
||||
max_num_emitted_tokens)
|
||||
else:
|
||||
assert math.isnan(metrics.draft_acceptance_rate)
|
||||
assert math.isnan(metrics.system_efficiency)
|
||||
@ -1,838 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.selector import (_Backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
|
||||
get_all_seq_ids)
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
from .utils import (assert_logprobs_dict_allclose, create_batch,
|
||||
create_seq_group_metadata_from_prompts, create_worker,
|
||||
patch_execute_model_with_seeds, zero_kv_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
|
||||
def test_assert_enough_kv_space(num_steps: int):
|
||||
"""Test that the multi step worker checks for sufficient space in the KV
|
||||
cache. It should throw if it cannot run all the steps.
|
||||
"""
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
prompts = [
|
||||
list(range(block_size * 3)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
prev_output_tokens = [
|
||||
list(range(block_size * 1)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
final_prompt_lens = [
|
||||
len(prompt + output) + num_steps
|
||||
for prompt, output in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
inputs = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens,
|
||||
continuations=prev_output_tokens)
|
||||
|
||||
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
|
||||
worker = MagicMock()
|
||||
worker.model_runner.block_size = block_size
|
||||
|
||||
for seq_group_metadata in inputs:
|
||||
original_block_tables = seq_group_metadata.block_tables
|
||||
|
||||
# No exception.
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = {
|
||||
seq_id: []
|
||||
for seq_id, physical_blocks in original_block_tables.items()
|
||||
}
|
||||
|
||||
# Expect exception.
|
||||
with pytest.raises(ValueError,
|
||||
match='times but found insufficient KV space for'):
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = original_block_tables
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_single_step():
|
||||
"""Verify the multi step worker produces the same output as the normal
|
||||
worker for num_steps=1.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
worker = create_worker(
|
||||
Worker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
# multi_step_worker.model_runner = worker.model_runner
|
||||
# multi_step_worker.cache_engine = worker.cache_engine
|
||||
|
||||
num_steps = 1
|
||||
|
||||
prompts = [
|
||||
[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10],
|
||||
]
|
||||
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
multi_step_seq_group = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
actual_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=multi_step_seq_group),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
assert len(actual_output) == num_steps
|
||||
actual_output = actual_output[0]
|
||||
|
||||
single_step_seq_group = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
expected_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=single_step_seq_group))[0]
|
||||
|
||||
actual_token_ids = [
|
||||
output.samples[0].output_token for output in actual_output
|
||||
]
|
||||
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
|
||||
|
||||
expected_token_ids = [
|
||||
output.samples[0].output_token for output in expected_output
|
||||
]
|
||||
expected_logprobs = [
|
||||
output.samples[0].logprobs for output in expected_output
|
||||
]
|
||||
|
||||
assert actual_token_ids == expected_token_ids
|
||||
|
||||
print(f'{actual_logprobs=}')
|
||||
print(f'{expected_logprobs=}')
|
||||
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_multi_step():
|
||||
"""Verify the multi-step worker produces the same output as the normal
|
||||
worker when num_steps > 1. This test runs the multi-step worker once, and
|
||||
then runs the worker num_steps times, and compares the output.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
worker = create_worker(
|
||||
Worker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
# Make sure we go over the block boundary.
|
||||
num_steps = block_size + 1
|
||||
|
||||
random.seed(seed)
|
||||
prompts = [[
|
||||
random.randint(0, 1000) for _ in range(random.randint(10, 20))
|
||||
] for _ in range(10)]
|
||||
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
|
||||
continuations = [[1] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
continuations = [[1] for _ in prompts]
|
||||
set_random_seed(seed)
|
||||
|
||||
for _ in multi_step_output:
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Get token ids and logprobs for comparison.
|
||||
multi_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
single_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
|
||||
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
for i, _ in enumerate(prompts):
|
||||
for multi_step, single_step in zip(multi_step_output,
|
||||
single_step_output):
|
||||
multi_step_output_token_ids[i].append(
|
||||
multi_step[i].samples[0].output_token)
|
||||
single_step_output_token_ids[i].append(
|
||||
single_step[i].samples[0].output_token)
|
||||
|
||||
multi_step_output_logprobs[i].append(
|
||||
multi_step[i].samples[0].logprobs)
|
||||
single_step_output_logprobs[i].append(
|
||||
single_step[i].samples[0].logprobs)
|
||||
|
||||
# Print per-sequence token ids
|
||||
for i, (multi_step_tokens, single_step_tokens) in enumerate(
|
||||
zip(multi_step_output_token_ids, single_step_output_token_ids)):
|
||||
print(f'{i=} {multi_step_tokens=}')
|
||||
print(f'{i=} {single_step_tokens=}')
|
||||
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
|
||||
|
||||
# Assert token ids are equal.
|
||||
for multi_step_tokens, single_step_tokens in zip(
|
||||
multi_step_output_token_ids, single_step_output_token_ids):
|
||||
assert multi_step_tokens == single_step_tokens
|
||||
|
||||
# Assert logprobs are equal.
|
||||
for multi_step_logprobs, single_step_logprobs in zip(
|
||||
multi_step_output_logprobs, single_step_output_logprobs):
|
||||
assert_logprobs_dict_allclose(multi_step_logprobs,
|
||||
single_step_logprobs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_multi_step_with_batch_expansion_correct_output():
|
||||
"""
|
||||
In this test we verify that the MultiStepWorker is able to handle bonus
|
||||
tokens correctly. The test verifies that if a sequence has a
|
||||
bonus token then the MultiStepWorker is able to expand the batch by adding
|
||||
new sequences corresponding to the sequences with bonus tokens. The
|
||||
expanded batch is then used for predicting the next tokens.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 128
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(
|
||||
Worker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
random.seed(seed)
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
num_steps = 2
|
||||
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
# Create the test continuations
|
||||
continuations = [[random.randint(0, 1000)] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run single-step twice to generate 2 tokens. This
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Create continuations for the MultiStepWorker. The continuations have
|
||||
# 2 tokens in order to simulate the bonus token case.
|
||||
multi_step_continuations = []
|
||||
for continuation in continuations:
|
||||
multi_step_continuations.append(continuation[:2])
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step and verify that the third token prediction is accurate
|
||||
# for all sequences.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
all_seq_ids = {i for i in range(batch_size)}
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=1,
|
||||
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
|
||||
for index, output in enumerate(multi_step_output[-1].outputs):
|
||||
assert (continuations[index][-1] == output.samples[0].output_token)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_multi_step_with_batch_expansion_incorrect_output():
|
||||
"""
|
||||
Tests the MultiStepWorker's ability to handle batch expansion with bonus
|
||||
tokens in a negative case scenario. This test provides the MultiStepWorker
|
||||
with a batch containing sequences with bonus tokens but specifies the
|
||||
sequence IDs with bonus tokens incorrectly. The test verifies that the
|
||||
MultiStepWorker generates correct tokens for the sequences where the
|
||||
sequence ID is specified correctly and incorrect tokens for those where
|
||||
the sequence ID is specified incorrectly.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 128
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(
|
||||
Worker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
random.seed(seed)
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
num_steps = 2
|
||||
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
# Create the test continuations
|
||||
continuations = [[random.randint(0, 1000)] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
# Run single-step twice to generate 2 tokens. This
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Create continuations for the MultiStepWorker. The continuations have
|
||||
# 2 tokens in order to simulate the bonus token case.
|
||||
multi_step_continuations = []
|
||||
for continuation in continuations:
|
||||
multi_step_continuations.append(continuation[:2])
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step. In this run INCORRECTLY specify that only the odd number
|
||||
# sequences have bonus tokens. Verify that with this setting the third token
|
||||
# prediction is accurate only for the odd numbered sequences. Also verify
|
||||
# that the prediction might be wrong for some of the even numbered
|
||||
# sequences.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=1,
|
||||
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
|
||||
num_mismatch = 0
|
||||
for index, output in enumerate(multi_step_output[-1].outputs):
|
||||
if (index % 2) != 0:
|
||||
assert (continuations[index][-1] == output.samples[0].output_token)
|
||||
elif (continuations[index][-1] != output.samples[0].output_token):
|
||||
num_mismatch += 1
|
||||
# The prediction is accurate for some of the sequences even without proper
|
||||
# handling of the bonus tokens. Hence verify that the number of sequences
|
||||
# for which there is a mismatch is > 0.
|
||||
assert (num_mismatch > 0)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
|
||||
# The choice of backends forces the multi_step_worker to choose between
|
||||
# the vanilla model_runner and TP1DraftModelRunner and that we can test
|
||||
# both code paths.
|
||||
@pytest.mark.parametrize('attn_backend',
|
||||
[_Backend.XFORMERS, _Backend.FLASH_ATTN])
|
||||
def test_multi_step_correct_kvcache(num_steps, attn_backend):
|
||||
"""Verify that the KV cache of the draft model
|
||||
is correctly updated for sequences with bonus token.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = "JackFram/llama-68m"
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 1
|
||||
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
|
||||
multi_step_worker = create_worker(MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
dtype=dtype)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(Worker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
dtype=dtype)
|
||||
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
# Already generate two tokens for the sequence
|
||||
# so that we can simulate the bonus token case
|
||||
multi_step_continuations = [[
|
||||
random.randint(0, 1000),
|
||||
random.randint(0, 1000)
|
||||
] for _ in prompts]
|
||||
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
|
||||
|
||||
seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=
|
||||
seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
# Generate the kv cache for the bonus token first
|
||||
single_step_continuations = [c[:1] for c in multi_step_continuations]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=single_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list))
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
single_step_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list))
|
||||
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
multi_step_continuations[i].append(
|
||||
seq_group_output.samples[0].output_token)
|
||||
|
||||
# Verify that the KV cache of the single-step and
|
||||
# multi-step workers are the same.
|
||||
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
|
||||
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
|
||||
num_layers = len(single_step_gpu_cache)
|
||||
allclose = lambda a, b: torch.allclose(
|
||||
a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
|
||||
for i in range(num_layers):
|
||||
assert allclose(single_step_gpu_cache[i][0],
|
||||
multi_step_gpu_cache[i][0])
|
||||
assert allclose(single_step_gpu_cache[i][1],
|
||||
multi_step_gpu_cache[i][1])
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_full_speculation_len():
|
||||
"""Verify Top1Proposer correctly handles case where all sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'cuda:0'
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=2048,
|
||||
)
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(batch_size,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(batch_size,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, ),
|
||||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
], True
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_no_speculations():
|
||||
"""Verify Top1Proposer correctly handles case where no sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'cuda:0'
|
||||
prompt_len = 10
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=prompt_len + k - 1,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prompt_len=prompt_len)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_mixed_k():
|
||||
"""Verify Top1Proposer correctly handles case some sequences can
|
||||
speculate and some can't.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'cuda:0'
|
||||
|
||||
small_prompt_len = 5
|
||||
long_prompt_len = 10
|
||||
prev_output_token_len = 20
|
||||
|
||||
expected_num_proposal_seqs = 6
|
||||
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
|
||||
|
||||
prompt_len = [
|
||||
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
|
||||
] + [long_prompt_len
|
||||
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
|
||||
)
|
||||
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(expected_num_proposal_seqs,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(expected_num_proposal_seqs, ),
|
||||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
], True
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(
|
||||
batch_size,
|
||||
k,
|
||||
prompt_len=prompt_len,
|
||||
prev_output_token_len=prev_output_token_len,
|
||||
)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [
|
||||
k for _ in range(expected_num_proposal_seqs - 1)
|
||||
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_use_draft_model_runner_advance_step():
|
||||
"""Verify that draft model runner triggers advance step
|
||||
when applicable.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
k = 5
|
||||
batch_size = 32
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
# Mock "_gpu_advance_step" to raise an exception when called.
|
||||
exception_secret = "artificial stop"
|
||||
worker.model_runner._gpu_advance_step = MagicMock()
|
||||
worker.model_runner._gpu_advance_step.side_effect = ValueError(
|
||||
exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
# Fallback (should not call) when num_steps=1.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
num_steps=1)
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
# Expect exception if _gpu_advance_step is called.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
num_steps=k)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_expand_execute_model_request_sync_with_expand_hidden_states():
|
||||
"""
|
||||
In this test we verify that the logic for expanding the
|
||||
seq_group_metadata_list remains in sync with the expansion logic of
|
||||
the HiddenStates in _expand_execute_model_request.
|
||||
"""
|
||||
k = 5
|
||||
batch_size = 16
|
||||
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
execute_model_request = ExecuteModelRequest(
|
||||
seq_group_metadata_list,
|
||||
previous_hidden_states=HiddenStates(
|
||||
torch.arange(batch_size), seq_group_metadata_list,
|
||||
torch.arange(batch_size, 2 * batch_size)))
|
||||
|
||||
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
|
||||
_expand_execute_model_request(execute_model_request,
|
||||
seq_with_bonus_token_in_last_step)
|
||||
|
||||
all_seq_ids = torch.tensor(
|
||||
get_all_seq_ids(
|
||||
expanded_execute_model_request.seq_group_metadata_list))
|
||||
ref_expanded_hidden_states = all_seq_ids + batch_size
|
||||
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
|
||||
|
||||
assert (ref_expanded_hidden_states == expanded_execute_model_request.
|
||||
previous_hidden_states.hidden_states).all().item()
|
||||
@ -1,221 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from .utils import create_seq_group_metadata_from_prompts, create_worker
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_single_no_match():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario cannot find any candidate in one single batch
|
||||
"""
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'cuda:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([1])
|
||||
assert proposals.proposal_lens.tolist() == [0]
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario find some candidate not full in batchs
|
||||
"""
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'cuda:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
# shall find candidate 12,13,14,15,16
|
||||
[11, 12, 13, 14, 15, 16, 11],
|
||||
# shall find candidate 23,24,25,26,21
|
||||
[21, 21, 22, 23, 24, 25, 26, 21, 22],
|
||||
# shall find candidate 34,35,36,37,38
|
||||
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
|
||||
# shall find no candidate as exceed max_proposal_len
|
||||
[
|
||||
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
|
||||
38, 31, 32, 33
|
||||
],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([5])
|
||||
|
||||
# the first sequence has no match so proposal_len should be overwritten to 0
|
||||
assert proposals.proposal_lens.tolist(
|
||||
) == [0] + [proposal_len for _ in range(3)] + [0]
|
||||
|
||||
for i in range(proposal_len):
|
||||
assert proposals.proposal_token_ids[0][i] == -1
|
||||
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
|
||||
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
|
||||
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
|
||||
assert proposals.proposal_token_ids[4][i] == -1
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_batches_match_all():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario find candidate in all batches
|
||||
"""
|
||||
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'cuda:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [0, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find candidate 12,13,14,15,16
|
||||
[11, 12, 13, 14, 15, 16, 11],
|
||||
# shall find candidate 23,24,25,26,21
|
||||
[21, 21, 22, 23, 24, 25, 26, 21, 22],
|
||||
# shall find candidate 34,35,36,37,38
|
||||
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Normally drafter is run on decode requests only; here we check the output
|
||||
# of the ngram worker as it is the sole proposer that has no forward.
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([3])
|
||||
|
||||
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
|
||||
|
||||
for i in range(proposal_len):
|
||||
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
|
||||
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
|
||||
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]
|
||||
@ -1,116 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
|
||||
from vllm.spec_decode.mqa_scorer import MQAScorer
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
from .utils import create_batch, create_worker
|
||||
|
||||
|
||||
def create_proposal(propose_lens: list[int], vocab_size: int,
|
||||
device: str) -> SpeculativeProposals:
|
||||
batch_size = len(propose_lens)
|
||||
max_propose_len = max(propose_lens)
|
||||
proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
|
||||
device=device)
|
||||
|
||||
proposal_token_ids = torch.full((batch_size, max_propose_len),
|
||||
fill_value=-1,
|
||||
device=device)
|
||||
for i in range(batch_size):
|
||||
proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
|
||||
proposal_probs[i][:propose_lens[i]], dim=-1)
|
||||
|
||||
propose_lens = torch.tensor(propose_lens, device=device)
|
||||
return SpeculativeProposals(proposal_token_ids, proposal_probs,
|
||||
propose_lens)
|
||||
|
||||
|
||||
def assert_score_equal(score1: SpeculativeScores,
|
||||
score2: SpeculativeScores) -> None:
|
||||
assert torch.allclose(score1.probs, score2.probs)
|
||||
assert torch.allclose(score1.logprobs, score2.logprobs)
|
||||
assert torch.equal(
|
||||
score1.token_ids,
|
||||
score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
|
||||
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
|
||||
@pytest.mark.parametrize('mixed_propose_len', [True])
|
||||
@pytest.mark.parametrize('device', ['cuda'])
|
||||
@pytest.mark.parametrize('prefill_chunking', [False, True])
|
||||
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
|
||||
mixed_propose_len: bool, device: str,
|
||||
prefill_chunking: bool) -> None:
|
||||
"""
|
||||
Compare the batch expansion scorer and mqa scorer return the same score.
|
||||
We test for both queries with the same propose length and different
|
||||
propose length, as well as mixed prefill-decode batches.
|
||||
"""
|
||||
seed = 0
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
scorer_worker = create_worker(Worker, model_name, block_size,
|
||||
num_gpu_blocks, seed)
|
||||
scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer
|
||||
scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True
|
||||
scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True
|
||||
|
||||
vocab_size = scorer_worker.vocab_size
|
||||
|
||||
if not mixed_propose_len:
|
||||
propose_lens = [max_propose_len] * batch_size
|
||||
else:
|
||||
# There must be at least 1 decode request, otherwise
|
||||
# we have nothing to score (`_run_no_spec`).
|
||||
non_zero_cnt = random.randint(1, batch_size)
|
||||
propose_lens = [max_propose_len
|
||||
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
|
||||
random.shuffle(propose_lens)
|
||||
|
||||
seq_group_metadatalist, _, _ = create_batch(batch_size,
|
||||
max_propose_len,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
if mixed_propose_len and prefill_chunking and (n_prefills :=
|
||||
batch_size - non_zero_cnt):
|
||||
prefill, _, _ = create_batch(n_prefills,
|
||||
None,
|
||||
prefill_chunk_size=4,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
seq_ids=list(
|
||||
range(batch_size,
|
||||
batch_size + n_prefills)))
|
||||
# re-order to guarantee prefill|decode order
|
||||
target_group_metadatalist = [
|
||||
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
|
||||
if p > 0
|
||||
]
|
||||
seq_group_metadatalist = prefill + target_group_metadatalist
|
||||
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
|
||||
|
||||
proposals = create_proposal(propose_lens, vocab_size, device)
|
||||
requests = ExecuteModelRequest(seq_group_metadatalist,
|
||||
num_lookahead_slots=max_propose_len)
|
||||
|
||||
batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
|
||||
vocab_size)
|
||||
batch_expansion_score = batch_expansion_scorer.score_proposals(
|
||||
requests, proposals)
|
||||
|
||||
mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
|
||||
mqa_score = mqa_scorer.score_proposals(requests, proposals)
|
||||
|
||||
assert_score_equal(batch_expansion_score, mqa_score)
|
||||
@ -1,945 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceOutput
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
|
||||
split_num_cache_blocks_evenly)
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
from .test_utils import mock_spec_decode_sampler
|
||||
from .utils import (create_batch, create_sampler_output_list, create_worker,
|
||||
mock_worker)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_draft_model(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the draft worker with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
call_args_list = draft_worker.get_spec_proposals.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
|
||||
for args, _ in call_args_list:
|
||||
actual_execute_model_data = args[0]
|
||||
assert actual_execute_model_data == execute_model_req
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_batch_expansion_correctly_calls_target_model(
|
||||
k: int, batch_size: int, acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the target model with correct
|
||||
inputs with batch expansion. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
disable_mqa_scorer=True)
|
||||
worker.init_device()
|
||||
|
||||
vocab_size = 32_000
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
|
||||
device='cuda') * k
|
||||
|
||||
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
|
||||
batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
target_worker.execute_model.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
seen_contexts: list[list[int]] = []
|
||||
|
||||
call_args_list = target_worker.execute_model.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
for _, kwargs in call_args_list:
|
||||
seq_group_metadata_list = kwargs[
|
||||
"execute_model_req"].seq_group_metadata_list
|
||||
|
||||
assert len(seq_group_metadata_list) == (k + 1) * batch_size
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seen_contexts.append(seq_data.get_token_ids())
|
||||
|
||||
expected_seen_contexts: list[list[int]] = []
|
||||
|
||||
for prompt, prev_generated, draft_tokens in zip(
|
||||
prompts, prev_output_tokens, proposal_token_ids.tolist()):
|
||||
|
||||
for i in range(len(draft_tokens) + 1):
|
||||
expected_seen_contexts.append(prompt + prev_generated +
|
||||
draft_tokens[:i])
|
||||
|
||||
seen_contexts.sort()
|
||||
expected_seen_contexts.sort()
|
||||
assert expected_seen_contexts == seen_contexts
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the rejection sampler with
|
||||
correct inputs. Everything else is mocked out.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
|
||||
device='cuda') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
|
||||
spec_decode_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
assert len(spec_decode_sampler.call_args_list) == 1
|
||||
_, kwargs = spec_decode_sampler.call_args_list[0]
|
||||
actual = SimpleNamespace(**kwargs)
|
||||
|
||||
assert torch.equal(actual.bonus_token_ids,
|
||||
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
|
||||
assert torch.equal(actual.target_with_bonus_probs,
|
||||
target_token_probs.reshape(batch_size, k + 1, -1))
|
||||
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
|
||||
assert torch.equal(actual.draft_probs, proposal_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_formats_output(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker formats sampler output correctly.
|
||||
Everything else is mocked out.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
|
||||
device='cuda') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
expected_output = create_sampler_output_list(
|
||||
token_ids=spec_decode_sampler_output.transpose(0, 1),
|
||||
probs=[None for _ in range(k + 1)],
|
||||
logprobs=[None for _ in range(k + 1)])
|
||||
|
||||
seq_ids = [
|
||||
next(iter(seq_group_metadata.seq_data.keys()))
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]
|
||||
actual_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
expected_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
|
||||
for step in output:
|
||||
for seq_group in step:
|
||||
for sample in seq_group.samples:
|
||||
seq_id = sample.parent_seq_id
|
||||
actual_output_by_seq[seq_id].append(sample)
|
||||
|
||||
for step in expected_output:
|
||||
for seq_group in step:
|
||||
for sample in seq_group.samples:
|
||||
seq_id = sample.parent_seq_id
|
||||
expected_output_by_seq[seq_id].append(sample)
|
||||
|
||||
all_seen_seq_ids = set(
|
||||
list(actual_output_by_seq.keys()) +
|
||||
list(expected_output_by_seq.keys()))
|
||||
for seq_id in all_seen_seq_ids:
|
||||
actual_by_step = actual_output_by_seq[seq_id]
|
||||
expected_by_step = expected_output_by_seq[seq_id]
|
||||
|
||||
for i in range(k + 1):
|
||||
if i >= len(actual_by_step):
|
||||
assert expected_by_step[i].output_token == -1
|
||||
continue
|
||||
assert actual_by_step[i].output_token == expected_by_step[
|
||||
i].output_token
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('returns_metrics', [True, False])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker collects metrics.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
|
||||
device='cuda') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
|
||||
mock_rejsample_metrics = MagicMock(
|
||||
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
|
||||
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
|
||||
mock_rejsample_metrics)
|
||||
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
|
||||
|
||||
call_args_list = (
|
||||
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
|
||||
assert len(call_args_list) == 1
|
||||
args, kwargs = call_args_list[0]
|
||||
assert args[0] == k or kwargs.get('k', -1) == k
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [0])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_k_equals_zero(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when k is zero. This happens during prefill.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prev_output_token_len=0)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [0, 5])
|
||||
@pytest.mark.parametrize('batch_size', [0])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_empty_input_batch(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when the input batch is empty. This can happen if the engine communicates
|
||||
to the workers information without scheduling a batch.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prev_output_token_len=0)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_init_device(acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||
well as other GPU initialization.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
worker.init_device()
|
||||
|
||||
draft_worker.init_device.assert_called_once()
|
||||
|
||||
target_worker.init_device.assert_called_once()
|
||||
|
||||
metrics_collector.init_tensors.assert_called_once()
|
||||
spec_decode_sampler.init_tensors.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_initialize_cache(acceptance_sampler_method):
|
||||
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
|
||||
workers.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
metrics_collector=metrics_collector)
|
||||
|
||||
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
|
||||
worker.initialize_cache(**kwargs)
|
||||
|
||||
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||
target_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
|
||||
@pytest.mark.parametrize('available_cpu_blocks', [500])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||
available_cpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
|
||||
Specifically, it should run profiling in the scorer worker, and then evenly
|
||||
split the blocks between proposer and scorer worker.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.determine_num_available_blocks.return_value = (
|
||||
available_gpu_blocks, available_cpu_blocks)
|
||||
target_worker.get_cache_block_size_bytes.return_value = (
|
||||
target_cache_block_size_bytes)
|
||||
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
|
||||
|
||||
target_worker.determine_num_available_blocks.assert_called_once()
|
||||
assert num_cpu_blocks == available_cpu_blocks
|
||||
|
||||
assert num_gpu_blocks == split_num_cache_blocks_evenly(
|
||||
target_cache_block_size_bytes, draft_kv_size_bytes,
|
||||
available_gpu_blocks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('available_gpu_blocks',
|
||||
list(range(20)) + [1024, 1024**2])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes',
|
||||
[2 * 2 * 4096, 2 * 2 * 8192])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int):
|
||||
"""Verify split_num_cache_blocks_evenly does not exceed original memory
|
||||
allocation in bytes.
|
||||
"""
|
||||
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
|
||||
draft_kv_size_bytes,
|
||||
available_gpu_blocks)
|
||||
assert (num_blocks * target_cache_block_size_bytes) + (
|
||||
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
|
||||
target_cache_block_size_bytes)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_populate_seq_ids_with_bonus_tokens():
|
||||
"""
|
||||
Verify that a call to _create_output_sampler_list correctly updates
|
||||
seq_with_bonus_token_in_last_step.
|
||||
|
||||
seq_with_bonus_token_in_last_step is an internal data structure in
|
||||
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
|
||||
tokens by the target model in their last forward pass. This state is
|
||||
maintained only for models relying on the KV cache, such as those using
|
||||
the MultiStepWorker.
|
||||
"""
|
||||
batch_size = 10
|
||||
k = 5
|
||||
vocab_size = 10000
|
||||
num_sequences_with_bonus_tokens = 5
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
draft_worker.device = 'cuda'
|
||||
# The sequence_ids attached to each sequence in the batch.
|
||||
# The sequence at index i has seq_id assigned_seq_ids[i]
|
||||
assigned_seq_ids = list(range(batch_size))
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
seq_ids=assigned_seq_ids,
|
||||
prev_output_token_len=10)
|
||||
target_token_logprobs = torch.rand(batch_size, (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
accepted_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for seq_id in seq_group_metadata.seq_data:
|
||||
expected_request_id_seq_ids_mapping[
|
||||
seq_group_metadata.request_id].add(seq_id)
|
||||
# Generate a random sample of sequence indexes with bonus tokens
|
||||
seq_indexes_with_bonus_tokens = random.sample(
|
||||
range(batch_size), num_sequences_with_bonus_tokens)
|
||||
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
|
||||
mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
|
||||
mask[seq_indexes_with_bonus_tokens] = False
|
||||
# Set the last token ID to -1 for all indices not in
|
||||
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
|
||||
# those indices.
|
||||
accepted_token_ids[mask, -1:] = -1
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
|
||||
# This set includes all sequence IDs in the batch as well as an additional
|
||||
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
|
||||
# the range [0, batch_size + num_extra_sequence_ids).
|
||||
num_extra_sequence_ids = 10
|
||||
worker._seq_with_bonus_token_in_last_step = set(
|
||||
range(batch_size + num_extra_sequence_ids))
|
||||
worker._create_output_sampler_list(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
accepted_token_ids=accepted_token_ids,
|
||||
target_logprobs=target_token_logprobs,
|
||||
prompt_logprobs=None,
|
||||
k=k,
|
||||
stage_times=(0, 0, 0))
|
||||
# Verify that _seq_with_bonus_token_in_last_step contains the following:
|
||||
# 1. Sequence IDs that were already present in
|
||||
# _seq_with_bonus_token_in_last_step but were not part of the current
|
||||
# batch are retained.
|
||||
# 2. Of the sequence IDs present in the current batch, only those with a
|
||||
# bonus token are retained in _seq_with_bonus_token_in_last_step.
|
||||
# Sequence IDs that are present in the current batch but do not have
|
||||
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
|
||||
expected_seq_ids_with_bonus_tokens = \
|
||||
set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
|
||||
additional_sequence_ids = \
|
||||
set(range(batch_size, batch_size + num_extra_sequence_ids))
|
||||
assert worker._seq_with_bonus_token_in_last_step == \
|
||||
expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
|
||||
assert worker._request_id_seq_id_mapping == \
|
||||
expected_request_id_seq_ids_mapping
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_handle_finished_requests():
|
||||
"""
|
||||
Test to verify that finished request IDs are appropriately processed to
|
||||
update the internal state of the SpecDecodeWorker.
|
||||
|
||||
This test initializes the SpecDecodeWorker with mock data, marks certain
|
||||
requests as finished, and ensures that the corresponding sequence IDs are
|
||||
correctly removed from the internal mappings.
|
||||
"""
|
||||
batch_size = 32
|
||||
k = 3
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
metrics_collector)
|
||||
# Initialize the request_id_seq_id_mapping mapping dict with a few fake
|
||||
# request ids and corresponding sequence ids.
|
||||
worker._request_id_seq_id_mapping = \
|
||||
{'request-1': {1,2,3}, 'request-2': {4,5,6,7},
|
||||
'request-3': {8,9}, 'request-4': {10,11}}
|
||||
# Initialize seq_with_bonus_token_in_last_step with a few fake
|
||||
# sequence ids.
|
||||
worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
# Mark requests with ids request-1 and request-3 as finished.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
finished_requests_ids=['request-1', 'request-3'])
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# Verify that request-1 and request-3 are removed from
|
||||
# request_id_seq_id_mapping
|
||||
assert worker._request_id_seq_id_mapping == \
|
||||
{'request-2': {4,5,6,7}, 'request-4': {10,11}}
|
||||
# Verify that all sequence ids corresponding to 'request-1'
|
||||
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
|
||||
assert worker._seq_with_bonus_token_in_last_step == \
|
||||
{4,5,10}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [3])
|
||||
@pytest.mark.parametrize('batch_size', [2, 32])
|
||||
@pytest.mark.parametrize("batch_composition",
|
||||
["prefill_only", "decode_only", "mixed"])
|
||||
@torch.inference_mode()
|
||||
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
|
||||
"""
|
||||
Verify SpecDecodeWorker calls match the expected flow.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
|
||||
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
# Create batch with combination of terminal/non-terminal prefill chunks
|
||||
# and decodes (different seq_ids).
|
||||
decodes, _, _ = create_batch(batch_size, k)
|
||||
# Pre-chunking here, get 'batch_size' chunks.
|
||||
prefill, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prefill_chunk_size=4,
|
||||
seq_ids=list(range(batch_size,
|
||||
batch_size * 2)))
|
||||
|
||||
if batch_composition == "prefill_only":
|
||||
n_prefills = batch_size
|
||||
elif batch_composition == "decode_only":
|
||||
n_prefills = 0
|
||||
else:
|
||||
n_prefills = random.randint(1, batch_size - 1)
|
||||
n_decodes = batch_size - n_prefills
|
||||
|
||||
prefill = random.sample(prefill, n_prefills)
|
||||
decodes = random.sample(decodes, n_decodes)
|
||||
target_group_metadata_list = prefill + decodes
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=target_group_metadata_list,
|
||||
# For prefill only batches we expect num_lookahead_slots = 0.
|
||||
num_lookahead_slots=k if n_decodes > 0 else 0)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
if not len(decodes):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# no spec run (prefill only)
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
else:
|
||||
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# but first draft still counted
|
||||
assert draft_worker.get_spec_proposals.call_count == 1
|
||||
|
||||
|
||||
def test_correctly_load_weight_for_eagle():
|
||||
"""
|
||||
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
|
||||
"""
|
||||
seed = 100
|
||||
block_size = 32
|
||||
num_gpu_blocks = 8096 // block_size
|
||||
target_worker = create_worker(
|
||||
Worker,
|
||||
"JackFram/llama-68m",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
draft_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
"abhigoyal/vllm-eagle-llama-68m-random",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False)
|
||||
worker.proposer_worker.maybe_load_lm_head_weight(
|
||||
target_worker.model_runner.model.lm_head.weight.data)
|
||||
assert torch.allclose(
|
||||
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
|
||||
worker.scorer_worker.model_runner.model.lm_head.weight.data)
|
||||
@ -1,150 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.sampler import _get_ranks
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
|
||||
from vllm.spec_decode.util import (get_sampled_token_logprobs,
|
||||
split_batch_by_proposal_len)
|
||||
|
||||
|
||||
def test_get_all_seq_ids():
|
||||
"""Verify get_all_seq_ids extracts all seq ids.
|
||||
"""
|
||||
expected_seq_ids = list(range(10)) + list(range(100, 110))
|
||||
|
||||
seq_group_metadata_list = [
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
seq_id: MagicMock(),
|
||||
},
|
||||
sampling_params=MagicMock(),
|
||||
block_tables={
|
||||
seq_id: MagicMock(),
|
||||
},
|
||||
lora_request=None,
|
||||
) for seq_id in expected_seq_ids
|
||||
]
|
||||
|
||||
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
assert actual_seq_ids == expected_seq_ids
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_sequence_group_metadata():
|
||||
seq_ids = list(range(3))
|
||||
return [
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
i: MagicMock(),
|
||||
},
|
||||
sampling_params=MagicMock(),
|
||||
block_tables={
|
||||
i: MagicMock(),
|
||||
},
|
||||
lora_request=None,
|
||||
) for i in seq_ids
|
||||
]
|
||||
|
||||
|
||||
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 0]
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
|
||||
]
|
||||
expected_indices = [0, 2]
|
||||
|
||||
assert filtered_groups == expected_groups
|
||||
assert indices == expected_indices
|
||||
|
||||
|
||||
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 2]
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
|
||||
]
|
||||
expected_indices = [1, 2]
|
||||
|
||||
assert filtered_groups == expected_groups
|
||||
assert indices == expected_indices
|
||||
|
||||
|
||||
def test_empty_inputs():
|
||||
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 0, 0]
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [1, 1, 1]
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def mock_spec_decode_sampler(acceptance_sampler_method):
|
||||
"""
|
||||
Returns either a RejectionSampler or TypicalAcceptanceSampler
|
||||
object depending on whether acceptance_sampler_method is
|
||||
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
|
||||
"""
|
||||
if acceptance_sampler_method == "rejection_sampler":
|
||||
sampler = MagicMock(spec=RejectionSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
elif acceptance_sampler_method == "typical_acceptance_sampler":
|
||||
sampler = MagicMock(spec=TypicalAcceptanceSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
else:
|
||||
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
|
||||
|
||||
|
||||
def test_get_sampled_token_logprobs():
|
||||
"""Verify get_sampled_token_logprobs returns consistent rankings
|
||||
with regular get_ranks when probabilities match exactly.
|
||||
"""
|
||||
logprob_tensor = torch.tensor(
|
||||
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
|
||||
sampled_token_tensor = torch.tensor([[1,
|
||||
0]]) # shape (num_steps, batch_size)
|
||||
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
|
||||
sampled_token_tensor)
|
||||
|
||||
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
|
||||
sampled_token_tensor.reshape(-1))
|
||||
|
||||
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
|
||||
@ -1,290 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from itertools import count
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceData, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
T = TypeVar("T", bound=Worker)
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
|
||||
return (seq_len + block_size - 1) // block_size
|
||||
|
||||
|
||||
def mock_worker(cls=None,
|
||||
vocab_size: int = 30_000,
|
||||
max_model_len: int = 2048,
|
||||
rank: int = 0,
|
||||
use_spec: bool = True) -> MagicMock:
|
||||
if cls is None:
|
||||
cls = Worker
|
||||
|
||||
spec = cls if use_spec else None
|
||||
|
||||
worker = MagicMock(spec=spec)
|
||||
worker.vocab_size = vocab_size
|
||||
worker.max_model_len = max_model_len
|
||||
worker.rank = rank
|
||||
worker.device = 'cuda:0'
|
||||
return worker
|
||||
|
||||
|
||||
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
|
||||
seed_iter = iter(rand_seeds)
|
||||
original_execute_model = worker.execute_model
|
||||
|
||||
def new_execute_model(*args, **kwargs):
|
||||
result = original_execute_model(*args, **kwargs)
|
||||
set_random_seed(next(seed_iter))
|
||||
return result
|
||||
|
||||
return new_execute_model
|
||||
|
||||
|
||||
def zero_kv_cache(cache_engine: list[CacheEngine]):
|
||||
assert cache_engine[0].gpu_cache
|
||||
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
|
||||
key_blocks.zero_()
|
||||
value_blocks.zero_()
|
||||
|
||||
|
||||
def create_worker(cls: Callable[..., T],
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
seed: int,
|
||||
is_driver_worker: bool = True,
|
||||
enforce_eager: bool = True,
|
||||
model_runner_cls: Optional[ModelRunner] = None,
|
||||
dtype: Optional[str] = "auto") -> T:
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
seed=seed,
|
||||
block_size=block_size,
|
||||
enforce_eager=enforce_eager,
|
||||
dtype=dtype,
|
||||
)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
|
||||
worker = cls(
|
||||
vllm_config=engine_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
model_runner_cls=model_runner_cls,
|
||||
)
|
||||
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
|
||||
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
engine_config.cache_config.num_cpu_blocks = 0
|
||||
worker.initialize_cache(
|
||||
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
|
||||
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
|
||||
|
||||
return worker
|
||||
|
||||
|
||||
def create_seq_group_metadata_from_prompts(
|
||||
prompts: list[list[int]],
|
||||
num_gpu_blocks: int,
|
||||
block_size: int,
|
||||
final_prompt_lens: list[int],
|
||||
continuations: Optional[list[list[int]]] = None,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if continuations is None:
|
||||
continuations = [[] for _ in prompts]
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(i for i, _ in enumerate(prompts))
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = {
|
||||
i: [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(final_len, block_size))
|
||||
]
|
||||
for i, final_len in enumerate(final_prompt_lens)
|
||||
}
|
||||
|
||||
seq_grou_metadata_list = []
|
||||
for i, (prompt_token_ids,
|
||||
cont_token_ids) in enumerate(zip(prompts, continuations)):
|
||||
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
|
||||
data.update_num_computed_tokens(
|
||||
len(prompt_token_ids) + len(cont_token_ids) - 1)
|
||||
seq_data = {i: data}
|
||||
seq_grou_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=len(cont_token_ids) == 0,
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations[i][:]},
|
||||
))
|
||||
return seq_grou_metadata_list
|
||||
|
||||
|
||||
def create_chunked_seq_group_metadata_from_prompt(
|
||||
prompt: list[int],
|
||||
num_gpu_blocks: int,
|
||||
chunk_size: int,
|
||||
block_size: int,
|
||||
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if seq_id is None:
|
||||
seq_id = 0
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(len(prompt), block_size))
|
||||
]
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
|
||||
chunk_ids = prompt[idx:idx + chunk_size]
|
||||
data = SequenceData.from_seqs(prompt)
|
||||
data.update_num_computed_tokens(idx)
|
||||
seq_data = {i: data}
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations},
|
||||
token_chunk_size=len(chunk_ids)))
|
||||
return seq_group_metadata_list
|
||||
|
||||
|
||||
def assert_logprobs_dict_allclose(
|
||||
actual_logprobs: list[dict[int, Logprob]],
|
||||
expected_logprobs: list[dict[int, Logprob]]) -> None:
|
||||
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
|
||||
actual_logprobs, expected_logprobs):
|
||||
assert set(single_step_actual_logprobs.keys()) == set(
|
||||
single_step_expected_logprobs.keys())
|
||||
for token_id in single_step_actual_logprobs:
|
||||
actual = torch.tensor(
|
||||
single_step_actual_logprobs[token_id].logprob)
|
||||
expected = torch.tensor(
|
||||
single_step_expected_logprobs[token_id].logprob)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def create_sampler_output_list(
|
||||
token_ids: torch.Tensor,
|
||||
probs: GenericSequence[Optional[torch.Tensor]],
|
||||
logprobs: GenericSequence[Optional[torch.Tensor]],
|
||||
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
|
||||
num_steps, batch_size = token_ids.shape
|
||||
token_ids_by_step = token_ids.tolist()
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(range(batch_size))
|
||||
|
||||
return [
|
||||
SamplerOutput(outputs=[
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
output_token=token_id,
|
||||
parent_seq_id=seq_ids[seq_index],
|
||||
logprobs={token_id: Logprob(0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for seq_index, token_id in enumerate(token_ids_by_step[step])
|
||||
],
|
||||
sampled_token_probs=probs[step],
|
||||
logprobs=logprobs[step],
|
||||
sampled_token_ids=token_ids[step])
|
||||
for step in range(num_steps)
|
||||
]
|
||||
|
||||
|
||||
def create_batch(batch_size,
|
||||
k,
|
||||
prompt_len: Union[int, list[int]] = 10,
|
||||
prev_output_token_len: int = 10,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
num_gpu_blocks: Optional[int] = None,
|
||||
block_size: Optional[int] = None,
|
||||
prefill_chunk_size: Optional[int] = None):
|
||||
if block_size is None:
|
||||
block_size = 8
|
||||
|
||||
if num_gpu_blocks is None:
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
iterator = count()
|
||||
|
||||
if isinstance(prompt_len, int):
|
||||
prompt_lens = [prompt_len for _ in range(batch_size)]
|
||||
else:
|
||||
prompt_lens = prompt_len
|
||||
|
||||
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
|
||||
|
||||
if prefill_chunk_size:
|
||||
# Create a batch of chunked prompts.
|
||||
if not seq_ids:
|
||||
seq_ids = list(range(len(prompts)))
|
||||
seq_group_metadata_list = []
|
||||
for p, sid in zip(prompts, seq_ids):
|
||||
seq_group_metadata_list += \
|
||||
create_chunked_seq_group_metadata_from_prompt(
|
||||
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
|
||||
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
|
||||
prev_output_tokens = []
|
||||
else:
|
||||
prev_output_tokens = [[
|
||||
next(iterator) for _ in range(prev_output_token_len)
|
||||
] for _ in range(batch_size)]
|
||||
final_prompt_lens = [
|
||||
len(prompt) + len(prev_output_token) + k + 1
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids)
|
||||
return seq_group_metadata_list, prompts, prev_output_tokens
|
||||
|
||||
|
||||
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
|
||||
if prefill_chunk_size > 0:
|
||||
llm_kwargs.update(
|
||||
**{
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": prefill_chunk_size,
|
||||
"max_num_seqs": prefill_chunk_size
|
||||
})
|
||||
else:
|
||||
llm_kwargs["enable_chunked_prefill"] = False
|
||||
@ -29,7 +29,6 @@ def test_sampler_output_initialization(sampler_output, sample_outputs):
|
||||
assert len(sampler_output) == len(sample_outputs)
|
||||
assert sampler_output.sampled_token_probs is None
|
||||
assert sampler_output.sampled_token_ids is None
|
||||
assert sampler_output.spec_decode_worker_metrics is None
|
||||
|
||||
|
||||
def test_sampler_output_getitem(sampler_output, sample_outputs):
|
||||
|
||||
@ -40,12 +40,6 @@ def test_unsupported_configs(monkeypatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
AsyncEngineArgs(
|
||||
model=MODEL,
|
||||
kv_cache_dtype="fp8",
|
||||
).create_engine_config()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
AsyncEngineArgs(
|
||||
model=MODEL,
|
||||
|
||||
@ -32,6 +32,5 @@ run_mypy vllm/lora
|
||||
run_mypy vllm/model_executor
|
||||
run_mypy vllm/plugins
|
||||
run_mypy vllm/prompt_adapter
|
||||
run_mypy vllm/spec_decode
|
||||
run_mypy vllm/worker
|
||||
run_mypy vllm/v1
|
||||
|
||||
@ -2536,8 +2536,6 @@ class DeviceConfig:
|
||||
|
||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp"]
|
||||
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
|
||||
"typical_acceptance_sampler"]
|
||||
|
||||
|
||||
@config
|
||||
@ -2560,13 +2558,6 @@ class SpeculativeConfig:
|
||||
|
||||
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
||||
`prompt_lookup_min` should be considered."""
|
||||
acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
|
||||
"""The method to use for accepting draft tokens:\n
|
||||
- "rejection_sampler" maps to `RejectionSampler`.\n
|
||||
- "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.
|
||||
|
||||
If using `typical_acceptance_sampler`, the related configuration
|
||||
`posterior_threshold` and `posterior_alpha` should be considered."""
|
||||
draft_tensor_parallel_size: Optional[int] = None
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
@ -2593,9 +2584,6 @@ class SpeculativeConfig:
|
||||
will use the default version."""
|
||||
|
||||
# Advanced control
|
||||
disable_mqa_scorer: bool = False
|
||||
"""Disable the MQA scorer and fall back to batch expansion for scoring
|
||||
proposals."""
|
||||
disable_by_batch_size: Optional[int] = None
|
||||
"""Disable speculative decoding for new incoming requests when the number
|
||||
of enqueued requests is larger than this value, if provided."""
|
||||
@ -2608,16 +2596,6 @@ class SpeculativeConfig:
|
||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||
provided. Defaults to 1."""
|
||||
|
||||
# Typical acceptance sampler configuration
|
||||
posterior_threshold: Optional[float] = None
|
||||
"""A threshold value that sets a lower bound on the posterior probability
|
||||
of a token in the target model for it to be accepted. This threshold is
|
||||
used only when we use the `TypicalAcceptanceSampler` for token acceptance.
|
||||
"""
|
||||
posterior_alpha: Optional[float] = None
|
||||
"""Scaling factor for entropy-based threshold, applied when using
|
||||
`TypicalAcceptanceSampler`."""
|
||||
|
||||
speculative_token_tree: Optional[str] = None
|
||||
"""Specifies the tree structure for speculative token generation.
|
||||
"""
|
||||
@ -2795,8 +2773,8 @@ class SpeculativeConfig:
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"deepseek_mtp"):
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
@ -2806,6 +2784,11 @@ class SpeculativeConfig:
|
||||
)
|
||||
else:
|
||||
self.method = "draft_model"
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or deepseek_mtp.")
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
@ -2864,12 +2847,6 @@ class SpeculativeConfig:
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size))
|
||||
|
||||
if self.acceptance_method == "typical_acceptance_sampler":
|
||||
if self.posterior_threshold is None:
|
||||
self.posterior_threshold = 0.09
|
||||
if self.posterior_alpha is None:
|
||||
self.posterior_alpha = 0.3
|
||||
|
||||
@staticmethod
|
||||
def _maybe_override_draft_max_model_len(
|
||||
speculative_max_model_len: Optional[int],
|
||||
@ -2975,30 +2952,6 @@ class SpeculativeConfig:
|
||||
if self.draft_model_config:
|
||||
self.draft_model_config.verify_with_parallel_config(
|
||||
self.draft_parallel_config)
|
||||
# Validate and set draft token acceptance related settings.
|
||||
|
||||
if self.acceptance_method is None:
|
||||
raise ValueError("acceptance_method is not set. "
|
||||
"Expected values are rejection_sampler or "
|
||||
"typical_acceptance_sampler.")
|
||||
|
||||
if (self.acceptance_method != 'rejection_sampler'
|
||||
and self.acceptance_method != 'typical_acceptance_sampler'):
|
||||
raise ValueError(
|
||||
"Expected acceptance_method to be either "
|
||||
"rejection_sampler or typical_acceptance_sampler. Instead it "
|
||||
f"is {self.acceptance_method}")
|
||||
|
||||
if self.acceptance_method == "typical_acceptance_sampler" and (
|
||||
(self.posterior_threshold is not None
|
||||
and self.posterior_threshold < 0) or
|
||||
(self.posterior_alpha is not None and self.posterior_alpha < 0)):
|
||||
raise ValueError(
|
||||
"Expected the posterior_threshold and posterior_alpha of "
|
||||
"typical_acceptance_sampler to be > 0. "
|
||||
"Instead found posterior_threshold = "
|
||||
f"{self.posterior_threshold} and posterior_alpha = "
|
||||
f"{self.posterior_alpha}")
|
||||
|
||||
if (self.disable_by_batch_size is not None
|
||||
and self.disable_by_batch_size < 2):
|
||||
|
||||
@ -1417,28 +1417,12 @@ class EngineArgs:
|
||||
return False
|
||||
|
||||
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
|
||||
is_ngram_enabled = False
|
||||
is_eagle_enabled = False
|
||||
is_medusa_enabled = False
|
||||
if self.speculative_config is not None:
|
||||
# This is supported but experimental (handled below).
|
||||
speculative_method = self.speculative_config.get("method")
|
||||
if speculative_method:
|
||||
if speculative_method in ("ngram", "[ngram]"):
|
||||
is_ngram_enabled = True
|
||||
elif speculative_method == "medusa":
|
||||
is_medusa_enabled = True
|
||||
elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
|
||||
is_eagle_enabled = True
|
||||
else:
|
||||
speculative_model = self.speculative_config.get("model")
|
||||
if speculative_model in ("ngram", "[ngram]"):
|
||||
is_ngram_enabled = True
|
||||
if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
|
||||
# Other speculative decoding methods are not supported yet.
|
||||
_raise_or_fallback(feature_name="Speculative Decoding",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
if (self.speculative_config is not None
|
||||
and self.speculative_config.get("method") == "draft_model"):
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding with draft model is not supported yet. "
|
||||
"Please consider using other speculative decoding methods "
|
||||
"such as ngram, medusa, eagle, or deepseek_mtp.")
|
||||
|
||||
# No XFormers so far.
|
||||
V1_BACKENDS = [
|
||||
|
||||
@ -1780,13 +1780,6 @@ class LLMEngine:
|
||||
num_generation_tokens_from_prefill_groups)
|
||||
num_tokens_iter = (num_generation_tokens_iter +
|
||||
num_prompt_tokens_iter)
|
||||
# Spec decode, if enabled, emits specialized metrics from the worker in
|
||||
# sampler output.
|
||||
if model_output and isinstance(model_output[0], SamplerOutput) and (
|
||||
model_output[0].spec_decode_worker_metrics is not None):
|
||||
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
|
||||
else:
|
||||
spec_decode_metrics = None
|
||||
|
||||
return Stats(
|
||||
now=now,
|
||||
@ -1808,7 +1801,6 @@ class LLMEngine:
|
||||
num_tokens_iter=num_tokens_iter,
|
||||
time_to_first_tokens_iter=time_to_first_tokens_iter,
|
||||
time_per_output_tokens_iter=time_per_output_tokens_iter,
|
||||
spec_decode_metrics=spec_decode_metrics,
|
||||
num_preemption_iter=num_preemption_iter,
|
||||
|
||||
# Request stats
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Counter as CollectionsCounter
|
||||
from typing import Dict, List, Optional, Type, Union, cast
|
||||
|
||||
@ -19,9 +18,6 @@ if ray is not None:
|
||||
else:
|
||||
ray_metrics = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
prometheus_client.disable_created_metrics()
|
||||
@ -199,30 +195,6 @@ class Metrics:
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=labelnames + [Metrics.labelname_finish_reason])
|
||||
|
||||
# Speculative decoding stats
|
||||
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
|
||||
name="vllm:spec_decode_draft_acceptance_rate",
|
||||
documentation="Speulative token acceptance rate.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.gauge_spec_decode_efficiency = self._gauge_cls(
|
||||
name="vllm:spec_decode_efficiency",
|
||||
documentation="Speculative decoding system efficiency.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
|
||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames))
|
||||
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
|
||||
name="vllm:spec_decode_num_draft_tokens_total",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames)
|
||||
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
|
||||
name="vllm:spec_decode_num_emitted_tokens_total",
|
||||
documentation="Number of emitted tokens.",
|
||||
labelnames=labelnames))
|
||||
|
||||
|
||||
# --8<-- [end:metrics-definitions]
|
||||
|
||||
@ -391,9 +363,6 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Update spec decode metrics
|
||||
self.maybe_update_spec_decode_metrics(stats)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
@ -435,10 +404,6 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
stats.gpu_prefix_cache_hit_rate * 100,
|
||||
stats.cpu_prefix_cache_hit_rate * 100,
|
||||
)
|
||||
if self.spec_decode_metrics is not None:
|
||||
log_fn(
|
||||
self._format_spec_decode_metrics_str(
|
||||
self.spec_decode_metrics))
|
||||
|
||||
self._reset(stats, prompt_throughput, generation_throughput)
|
||||
|
||||
@ -447,21 +412,9 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.num_prompt_tokens = []
|
||||
self.num_generation_tokens = []
|
||||
self.last_local_log = stats.now
|
||||
self.spec_decode_metrics = None
|
||||
self.last_prompt_throughput = prompt_throughput
|
||||
self.last_generation_throughput = generation_throughput
|
||||
|
||||
def _format_spec_decode_metrics_str(
|
||||
self, metrics: "SpecDecodeWorkerMetrics") -> str:
|
||||
|
||||
return ("Speculative metrics: "
|
||||
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
|
||||
f"System efficiency: {metrics.system_efficiency:.3f}, "
|
||||
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
|
||||
f"Number of accepted tokens: {metrics.accepted_tokens}, "
|
||||
f"Number of draft tokens: {metrics.draft_tokens}, "
|
||||
f"Number of emitted tokens: {metrics.emitted_tokens}.")
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -579,33 +532,14 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Update spec decode metrics
|
||||
self.maybe_update_spec_decode_metrics(stats)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
if self.spec_decode_metrics is not None:
|
||||
self._log_gauge(
|
||||
self.metrics.gauge_spec_decode_draft_acceptance_rate,
|
||||
self.spec_decode_metrics.draft_acceptance_rate)
|
||||
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
|
||||
self.spec_decode_metrics.system_efficiency)
|
||||
self._log_counter(
|
||||
self.metrics.counter_spec_decode_num_accepted_tokens,
|
||||
self.spec_decode_metrics.accepted_tokens)
|
||||
self._log_counter(
|
||||
self.metrics.counter_spec_decode_num_draft_tokens,
|
||||
self.spec_decode_metrics.draft_tokens)
|
||||
self._log_counter(
|
||||
self.metrics.counter_spec_decode_num_emitted_tokens,
|
||||
self.spec_decode_metrics.emitted_tokens)
|
||||
|
||||
# Reset tracked stats for next interval.
|
||||
self.num_prompt_tokens = []
|
||||
self.num_generation_tokens = []
|
||||
self.last_local_log = stats.now
|
||||
self.spec_decode_metrics = None
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
# Info type metrics are syntactic sugar for a gauge permanently set to 1
|
||||
|
||||
@ -16,10 +16,9 @@ do this in Python code and lazily import prometheus_client.
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from vllm.config import SupportsMetricsInfo, VllmConfig
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -65,8 +64,6 @@ class Stats:
|
||||
running_lora_adapters: List[str]
|
||||
max_lora: str
|
||||
|
||||
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||
|
||||
|
||||
class StatLoggerBase(ABC):
|
||||
"""Base class for StatLogger."""
|
||||
@ -77,7 +74,6 @@ class StatLoggerBase(ABC):
|
||||
self.num_generation_tokens: List[int] = []
|
||||
self.last_local_log = time.time()
|
||||
self.local_interval = local_interval
|
||||
self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None
|
||||
|
||||
@abstractmethod
|
||||
def log(self, stats: Stats) -> None:
|
||||
@ -86,9 +82,3 @@ class StatLoggerBase(ABC):
|
||||
@abstractmethod
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def maybe_update_spec_decode_metrics(self, stats: Stats):
|
||||
"""Save spec decode metrics (since they are unlikely
|
||||
to be emitted at same time as log interval)."""
|
||||
if stats.spec_decode_metrics is not None:
|
||||
self.spec_decode_metrics = stats.spec_decode_metrics
|
||||
|
||||
@ -104,11 +104,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
seqs = sequence_group.get_seqs(
|
||||
status=SequenceStatus.FINISHED_ABORTED)
|
||||
|
||||
for output in outputs:
|
||||
if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
|
||||
sequence_group.metrics.spec_token_acceptance_counts[
|
||||
output.step_index] += 1
|
||||
|
||||
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
|
||||
assert len(seqs) == 1, (
|
||||
"Beam search not supported in multi-step decoding.")
|
||||
|
||||
@ -1,406 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import cached_property
|
||||
from importlib.util import find_spec
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeStochasticBaseSampler)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if find_spec("flashinfer"):
|
||||
"""
|
||||
Consider utilizing the FlashInfer rejection sampling kernel initially,
|
||||
as it employs a dedicated kernel rather than relying on
|
||||
Torch tensor operations. This design choice helps to fuse operations,
|
||||
reduce memory I/O, and consequently enhances performance.
|
||||
"""
|
||||
from flashinfer.sampling import chain_speculative_sampling
|
||||
else:
|
||||
chain_speculative_sampling = None
|
||||
|
||||
|
||||
class RejectionSampler(SpecDecodeStochasticBaseSampler):
|
||||
"""Apply modified rejection sampling as described in "Accelerating Large
|
||||
Language Model Decoding with Speculative Sampling"
|
||||
https://arxiv.org/pdf/2302.01318.pdf.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
strict_mode: bool = False,
|
||||
use_flashinfer: Optional[bool] = None):
|
||||
"""Create a rejection sampler.
|
||||
|
||||
Args:
|
||||
strict_mode: Whether or not to perform shape/device/dtype checks
|
||||
during sampling. This catches correctness issues but adds
|
||||
nontrivial latency.
|
||||
use_flashinfer: We will use this parameter to determine whether
|
||||
to use the FlashInfer rejection sampling kernel or not. If it's
|
||||
None, we will use the default value from the environment variable.
|
||||
This parameter is only used for testing purposes.
|
||||
"""
|
||||
super().__init__(strict_mode=strict_mode)
|
||||
if use_flashinfer is None:
|
||||
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
|
||||
chain_speculative_sampling is not None)
|
||||
else:
|
||||
self.use_flashinfer = use_flashinfer
|
||||
|
||||
if self.use_flashinfer:
|
||||
logger.info("Use flashinfer for rejection sampling.")
|
||||
else:
|
||||
logger.info("Use pytorch for rejection sampling.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Sample token ids using rejection sampling. This accepts or rejects
|
||||
tokens proposed by the draft model using the probability of each token
|
||||
according to the draft and target models.
|
||||
|
||||
In the worst case where all draft tokens are rejected, it is guaranteed
|
||||
one correct token will be emitted.
|
||||
|
||||
In the case where all draft tokens are accepted, a bonus token will be
|
||||
accepted as its cheap to have the target model score this speculative
|
||||
sequence.
|
||||
|
||||
Args:
|
||||
target_with_bonus_probs: The probability distribution
|
||||
over token ids given context according to the target model.
|
||||
shape = [batch_size, num_speculative_tokens + 1, vocab_size]
|
||||
|
||||
bonus_token_ids: The "bonus" token ids that are accepted iff all
|
||||
speculative tokens in a sequence are accepted.
|
||||
shape = [batch_size, num_bonus_tokens]
|
||||
|
||||
draft_probs: The probability distribution over token ids given
|
||||
context according to the draft model.
|
||||
shape = [batch_size, num_speculative_tokens, vocab_size]
|
||||
|
||||
draft_token_ids: The token ids that were sampled from the draft
|
||||
probabilities.
|
||||
shape = [batch_size, num_speculative_tokens]
|
||||
|
||||
seeded_seqs: Dict of batch row index to torch generator, for
|
||||
sequences using seeded generation.
|
||||
|
||||
Returns:
|
||||
output_token_ids: The token ids sampled via rejection sampling,
|
||||
or -1 if unable to sample a token because the previous token
|
||||
was rejected.
|
||||
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
|
||||
"""
|
||||
# Only perform shape/dtype/device checking in strict mode, as it adds
|
||||
# overhead.
|
||||
if self._strict_mode:
|
||||
self._raise_if_incorrect_input(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids,
|
||||
draft_probs)
|
||||
|
||||
batch_size, k, _ = draft_probs.shape
|
||||
|
||||
# batch_size = 0 when all requests in the batch are
|
||||
# non_spec requests. In this case, output_token_ids is
|
||||
# just an empty tensor.
|
||||
if batch_size == 0:
|
||||
return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)
|
||||
|
||||
# If use Flashinfer chain_speculative_sampling kernel
|
||||
# for rejection sampling
|
||||
if self.use_flashinfer and chain_speculative_sampling is not None:
|
||||
batch_size, k, _ = draft_probs.shape
|
||||
|
||||
(output_token_ids, accepted_token_num,
|
||||
emitted_token_num) = chain_speculative_sampling(
|
||||
draft_probs,
|
||||
draft_token_ids,
|
||||
target_with_bonus_probs,
|
||||
)
|
||||
|
||||
# num_emitted_tokens returned by flashinfer
|
||||
# does not include the bonus token
|
||||
# Flashinfer stops at the first token that violates
|
||||
# the condition p >= q and does not include recovery/bonus token.
|
||||
# Therefore, we need to add batch_size here.
|
||||
self.num_accepted_tokens += accepted_token_num.sum()
|
||||
self.num_emitted_tokens += emitted_token_num.sum() + batch_size
|
||||
self.num_draft_tokens += batch_size * k
|
||||
else:
|
||||
accepted, recovered_token_ids = (
|
||||
self._batch_modified_rejection_sampling(
|
||||
target_with_bonus_probs[:, :-1],
|
||||
draft_probs,
|
||||
draft_token_ids,
|
||||
seeded_seqs,
|
||||
))
|
||||
|
||||
output_token_ids = self._create_output(
|
||||
accepted,
|
||||
recovered_token_ids,
|
||||
draft_token_ids,
|
||||
bonus_token_ids,
|
||||
)
|
||||
|
||||
return output_token_ids
|
||||
|
||||
def _batch_modified_rejection_sampling(
|
||||
self,
|
||||
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
||||
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
||||
draft_token_ids: torch.Tensor, # [batch_size, k]
|
||||
seeded_seqs: Optional[dict[int, torch.Generator]],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Perform modified rejection sampling on each sequence.
|
||||
|
||||
Returns:
|
||||
A tuple of two tensors:
|
||||
0: A bool tensor of which tokens in each sequence is accepted.
|
||||
shape = [batch_size, k]
|
||||
1: Token ids sampled from a recovered distribution, to be used
|
||||
when a token is rejected.
|
||||
shape = [batch_size, k]
|
||||
"""
|
||||
|
||||
batch_size, k, vocab_size = draft_probs.shape
|
||||
|
||||
# shape [batch_size, k]
|
||||
accepted = self._get_accepted(target_probs, draft_probs,
|
||||
draft_token_ids, seeded_seqs)
|
||||
|
||||
recovered_probs = self._get_recovered_probs(
|
||||
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
|
||||
|
||||
# NOTE: the recovered_probs are overwritten by this method.
|
||||
recovered_token_ids = _multinomial(
|
||||
recovered_probs,
|
||||
num_samples=1,
|
||||
k=k,
|
||||
seeded_seqs=seeded_seqs or {},
|
||||
).reshape(batch_size, k)
|
||||
|
||||
return accepted, recovered_token_ids
|
||||
|
||||
def _create_uniform_samples(self,
|
||||
seeded_seqs: Optional[dict[int,
|
||||
torch.Generator]],
|
||||
batch_size: int, k: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Generates a batch of uniform random samples, with optional seeding
|
||||
for specific sequences.
|
||||
|
||||
This method creates a tensor of shape `(batch_size, k + 1)` filled
|
||||
with uniform random values in the range [0, 1). If `seeded_seqs`
|
||||
is provided, the sequences corresponding to specific indices
|
||||
will be generated using the provided `torch.Generator` for
|
||||
reproducibility. The other sequences will be generated without
|
||||
a seed.
|
||||
|
||||
Args:
|
||||
seeded_seqs : Optional[dict[int, torch.Generator]]
|
||||
A dictionary mapping indices in the batch to
|
||||
`torch.Generator` objects. If `None`, all samples are
|
||||
generated without a seed.
|
||||
batch_size : int
|
||||
The number of sequences to generate.
|
||||
k : int
|
||||
The number of random samples per sequence.
|
||||
device : torch.device
|
||||
The device on which to allocate the tensor.
|
||||
|
||||
Returns:
|
||||
uniform_rand : torch.Tensor
|
||||
A tensor of shape `(batch_size, k + 1)` containing uniform
|
||||
random values in the range [0, 1).
|
||||
"""
|
||||
if not seeded_seqs:
|
||||
return torch.rand(batch_size, k + 1, device=device)
|
||||
|
||||
uniform_rand = torch.empty(batch_size, k + 1, device=device)
|
||||
|
||||
non_seeded_indices = []
|
||||
for idx in range(batch_size):
|
||||
generator = seeded_seqs.get(idx)
|
||||
if generator is None:
|
||||
non_seeded_indices.append(idx)
|
||||
else:
|
||||
uniform_rand[idx, :] = torch.rand(1,
|
||||
k + 1,
|
||||
dtype=self.probs_dtype,
|
||||
device=device,
|
||||
generator=generator)
|
||||
if non_seeded_indices:
|
||||
uniform_rand[non_seeded_indices, :] = torch.rand(
|
||||
len(non_seeded_indices),
|
||||
k + 1,
|
||||
dtype=self.probs_dtype,
|
||||
device=device)
|
||||
return uniform_rand
|
||||
|
||||
def _get_accepted(
|
||||
self,
|
||||
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
||||
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
||||
draft_token_ids: torch.Tensor, # [batch_size, k]
|
||||
seeded_seqs: Optional[dict[int, torch.Generator]],
|
||||
) -> torch.Tensor:
|
||||
r"""Create bool matrix over the proposed draft tokens. If
|
||||
True, then a token can be accepted, else it should be
|
||||
rejected.
|
||||
|
||||
Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
|
||||
$\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
|
||||
to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
|
||||
same conditional probability according to the draft model, the token
|
||||
is accepted with probability:
|
||||
|
||||
$$
|
||||
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
|
||||
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
|
||||
$$
|
||||
|
||||
This implementation does not apply causality. When using the output,
|
||||
if a token is rejected, subsequent tokens should not be used.
|
||||
|
||||
Returns a bool tensor of shape [batch_size, k] specifying which tokens
|
||||
are accepted.
|
||||
"""
|
||||
batch_size, k, _ = draft_probs.shape
|
||||
batch_indices = torch.arange(batch_size,
|
||||
device=target_probs.device)[:, None]
|
||||
probs_indices = torch.arange(k, device=target_probs.device)
|
||||
|
||||
# shape [batch_size, k]
|
||||
selected_draft_probs = draft_probs[batch_indices, probs_indices,
|
||||
draft_token_ids]
|
||||
|
||||
# shape [batch_size, k]
|
||||
selected_target_probs = target_probs[batch_indices, probs_indices,
|
||||
draft_token_ids]
|
||||
|
||||
uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
|
||||
k - 1, target_probs.device)
|
||||
|
||||
capped_ratio = torch.minimum(
|
||||
selected_target_probs / selected_draft_probs,
|
||||
torch.full((1, ), 1, device=target_probs.device))
|
||||
accepted = uniform_rand < capped_ratio
|
||||
|
||||
return accepted
|
||||
|
||||
def _get_recovered_probs(
|
||||
self,
|
||||
target_probs: torch.Tensor, # [k, vocab_size]
|
||||
draft_probs: torch.Tensor, # [k, vocab_size]
|
||||
) -> torch.Tensor:
|
||||
r"""Create a probability distribution for each proposed token which can
|
||||
be sampled if the proposed token is rejected.
|
||||
|
||||
When this routine is applied sequentially, the true distribution of the
|
||||
target model is recovered (within hardware numerics).
|
||||
|
||||
The probability distribution used in this rejection case is constructed
|
||||
as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
|
||||
$x$ given context $x_1, \dots, x_n$ according to the target
|
||||
model and $p(x|x_1, \dots, x_n)$, the same conditional probability
|
||||
according to the draft model:
|
||||
|
||||
$$
|
||||
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
|
||||
$$
|
||||
|
||||
where $(f(x))_+$ is defined as:
|
||||
|
||||
$$
|
||||
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
|
||||
$$
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
|
||||
of the draft, target, and recovered probability distributions.
|
||||
|
||||
Returns a tensor of shape [batch_size, k, vocab_size].
|
||||
|
||||
Note:
|
||||
This batches operations on GPU and thus constructs the recovered
|
||||
distribution for all tokens, even if they are accepted. This causes
|
||||
division-by-zero errors, so we use self._smallest_positive_value to
|
||||
avoid that. This introduces some drift to the distribution.
|
||||
"""
|
||||
_, k, _ = draft_probs.shape
|
||||
|
||||
# shape [batch_size, k, vocab_size]
|
||||
difference = target_probs - draft_probs
|
||||
|
||||
# TODO(cade): Can we use logprobs instead of probs, and avoid the
|
||||
# division-by-zero errors without introducing distribution drift?
|
||||
|
||||
# shape [batch_size, k, vocab_size]
|
||||
f = torch.clamp(difference, min=self._smallest_positive_value)
|
||||
|
||||
# shape [batch_size, k, vocab_size]
|
||||
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
|
||||
|
||||
return recovered_probs
|
||||
|
||||
@cached_property
|
||||
def _smallest_positive_value(self) -> float:
|
||||
"""Return the smallest positive value representable by the probs dtype.
|
||||
This value is used when constructing a distribution from which to sample
|
||||
recovered tokens in the first rejection case.
|
||||
|
||||
See _get_recovered_probs for more details
|
||||
|
||||
Note that this isn't actually the smallest positive value representable
|
||||
by float32, but the smallest positive normal value.
|
||||
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
|
||||
"""
|
||||
return torch.finfo(self.probs_dtype).tiny
|
||||
|
||||
|
||||
# torch.multinomial forces a GPU<->CPU sync.
|
||||
# Therefore, we use an optimized implementation instead that skips the sync.
|
||||
# Note that we always sample with replacement.
|
||||
# probs will be modified in place, but this is fine, as we pass
|
||||
# in a copy already.
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def _multinomial(
|
||||
probs: torch.Tensor,
|
||||
num_samples: int,
|
||||
k: int,
|
||||
seeded_seqs: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
|
||||
if num_samples > 1:
|
||||
# This is equivalent to torch.repeat_interleaved (which also
|
||||
# forces a GPU<->CPU sync).
|
||||
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
|
||||
probs.shape[1]).contiguous().view(
|
||||
-1, probs.shape[1])
|
||||
q = torch.empty_like(probs)
|
||||
if not seeded_seqs:
|
||||
q.exponential_(1.0)
|
||||
else:
|
||||
start = 0
|
||||
for idx in range(len(q) // k):
|
||||
end = start + k
|
||||
generator = seeded_seqs.get(idx)
|
||||
# Note: generator might be None for non seeded
|
||||
q[start:end].exponential_(1.0, generator=generator)
|
||||
start = end
|
||||
|
||||
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
||||
@ -21,7 +21,6 @@ from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
||||
CompletionSequenceGroupOutput, Logprob,
|
||||
PromptLogprobs, SampleLogprobs, SequenceOutput)
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
||||
# yapf: disable
|
||||
@ -119,9 +118,6 @@ class SamplerOutput(
|
||||
# specified in lieu of prompt token ids or text.
|
||||
sampled_token_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
# Spec decode metrics populated by workers.
|
||||
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
|
||||
|
||||
# Optional last hidden states from the model.
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
@ -159,11 +155,9 @@ class SamplerOutput(
|
||||
else self.sampled_token_probs.shape)
|
||||
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
|
||||
self.sampled_token_ids.shape)
|
||||
return (
|
||||
f"SamplerOutput(outputs={self.outputs}, "
|
||||
f"sampled_token_probs={sampled_token_probs_repr}, "
|
||||
f"sampled_token_ids={sampled_token_ids_repr}, "
|
||||
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|
||||
return (f"SamplerOutput(outputs={self.outputs}, "
|
||||
f"sampled_token_probs={sampled_token_probs_repr}, "
|
||||
f"sampled_token_ids={sampled_token_ids_repr})")
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
@ -1,259 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class SpecDecodeBaseSampler(nn.Module):
|
||||
"""Base class for samplers used for Speculative Decoding verification
|
||||
step.
|
||||
"""
|
||||
|
||||
def __init__(self, strict_mode: bool = False):
|
||||
"""Base class constructor.
|
||||
Args:
|
||||
strict_mode: Whether or not to perform shape/device/dtype checks
|
||||
during sampling. This catches correctness issues but adds
|
||||
nontrivial latency.
|
||||
"""
|
||||
super().__init__()
|
||||
self._strict_mode = strict_mode
|
||||
|
||||
# NOTE: A "bonus token" is accepted iff all proposal tokens are
|
||||
# accepted. There is always only one possible bonus token. We store this
|
||||
# value in a variable for readability.
|
||||
self._num_bonus_tokens = 1
|
||||
|
||||
self.num_accepted_tokens: Optional[torch.Tensor] = None
|
||||
self.num_emitted_tokens: Optional[torch.Tensor] = None
|
||||
self.num_draft_tokens: int = 0
|
||||
|
||||
def init_gpu_tensors(self, device: Union[int, str]) -> None:
|
||||
assert self.num_accepted_tokens is None
|
||||
if isinstance(device, int):
|
||||
device = f"{current_platform.device_type}:{device}"
|
||||
elif not isinstance(device, str):
|
||||
raise ValueError(f"Device must be int or str, get {type(device)}")
|
||||
self.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
self.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
def init_tensors(self,
|
||||
device: Union[int, str],
|
||||
device_type: Union[torch.device, str] = 'cuda') -> None:
|
||||
assert self.num_accepted_tokens is None
|
||||
if isinstance(device_type, torch.device):
|
||||
device_type = device_type.type
|
||||
if isinstance(device, int):
|
||||
device = f"{device_type}:{device}"
|
||||
self.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
self.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
@property
|
||||
def probs_dtype(self):
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def token_id_dtype(self):
|
||||
return torch.int64
|
||||
|
||||
def _create_output(
|
||||
self,
|
||||
accepted: torch.Tensor, # [batch_size, k]
|
||||
substitute_token_ids: torch.Tensor, # [batch_size, k]
|
||||
draft_token_ids: torch.Tensor, # [batch_size, k]
|
||||
bonus_token_ids: torch.Tensor, # [batch_size]
|
||||
) -> torch.Tensor:
|
||||
"""Format output. Returns a matrix of token ids. When
|
||||
a token is rejected via sampling, all subsequent token ids are
|
||||
set to -1 for the sequence.
|
||||
|
||||
Args:
|
||||
accepted: A boolean tensor indicating if the corresponding
|
||||
draft token in draft_token_ids should be accepted or not.
|
||||
substitute_token_ids: A tensor of token_ids that can be used
|
||||
as substitutes for the draft token ids if the proposed token
|
||||
is rejected.
|
||||
draft_token_ids: A tensor of token ids speculated by the
|
||||
draft model.
|
||||
bonus_token_ids: Token ids to use as the bonus token if
|
||||
all the draft tokens are accepted.
|
||||
Returns:
|
||||
A tensor containing the accepted token ids. The shape of the
|
||||
tensor is [batch_size, k + num_bonus_tokens]
|
||||
"""
|
||||
batch_size, k = substitute_token_ids.shape
|
||||
bonus_token_ids = bonus_token_ids.squeeze(-1)
|
||||
# Determine the index of the first False value for each row.
|
||||
limits = (accepted == 0).max(1).indices
|
||||
limits[~(accepted == 0).any(1)] = k
|
||||
|
||||
# Create masks using the indices.
|
||||
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
|
||||
accepted_mask = indices < limits.unsqueeze(1)
|
||||
after_false_mask = indices == limits.unsqueeze(1)
|
||||
|
||||
# Create an extended output tensor
|
||||
output_with_bonus_tokens = -torch.ones(
|
||||
(batch_size, k + self._num_bonus_tokens),
|
||||
dtype=self.token_id_dtype,
|
||||
device=accepted.device)
|
||||
output = output_with_bonus_tokens[:, :k]
|
||||
|
||||
# Fill in the first k columns of the output tensor using masks and data
|
||||
# tensors.
|
||||
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
|
||||
-torch.ones_like(draft_token_ids))
|
||||
|
||||
# Fill the last column.
|
||||
# We check output directly as accepted may have True values inconsistent
|
||||
# with causal acceptance.
|
||||
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
|
||||
bonus_token_ids, -1)
|
||||
|
||||
# Fill the recovered token ids.
|
||||
output.mul_(~after_false_mask).add_(
|
||||
substitute_token_ids.mul(after_false_mask))
|
||||
|
||||
self.num_accepted_tokens += accepted.sum()
|
||||
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
|
||||
self.num_draft_tokens += batch_size * k
|
||||
|
||||
return output_with_bonus_tokens
|
||||
|
||||
def _raise_if_incorrect_input(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
self._raise_if_incorrect_shape(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids,
|
||||
draft_probs)
|
||||
self._raise_if_incorrect_dtype(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids,
|
||||
draft_probs)
|
||||
self._raise_if_inconsistent_device(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids,
|
||||
draft_probs)
|
||||
self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
|
||||
draft_token_ids, bonus_token_ids)
|
||||
|
||||
def _raise_if_incorrect_shape(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
(target_batch_size, num_target_probs,
|
||||
target_vocab_size) = target_with_bonus_probs.shape
|
||||
|
||||
# Does not count the extra token
|
||||
num_target_probs -= 1
|
||||
|
||||
# validate the shape of draft token ids.
|
||||
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
|
||||
assert draft_token_ids_batch_size == target_batch_size
|
||||
assert num_draft_token_ids == num_target_probs
|
||||
|
||||
# validate the shape of bonus token ids
|
||||
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
|
||||
assert bonus_batch_size == target_batch_size
|
||||
assert num_bonus_tokens == self._num_bonus_tokens
|
||||
|
||||
# validate the shape of draft probs if it is set
|
||||
if draft_probs is not None:
|
||||
(draft_batch_size, num_draft_probs,
|
||||
draft_vocab_size) = draft_probs.shape
|
||||
assert draft_batch_size == target_batch_size
|
||||
assert num_draft_probs == num_target_probs
|
||||
assert (draft_vocab_size == target_vocab_size
|
||||
), f"{draft_vocab_size=} {target_vocab_size=}"
|
||||
|
||||
def _raise_if_incorrect_dtype(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
assert target_with_bonus_probs.dtype == self.probs_dtype
|
||||
assert draft_token_ids.dtype == self.token_id_dtype
|
||||
assert bonus_token_ids.dtype == self.token_id_dtype
|
||||
if draft_probs is not None:
|
||||
assert draft_probs.dtype == self.probs_dtype
|
||||
|
||||
def _raise_if_inconsistent_device(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
devices = [
|
||||
t.device for t in [
|
||||
target_with_bonus_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids
|
||||
] if t is not None
|
||||
]
|
||||
assert all([devices[0] == device for device in devices])
|
||||
|
||||
def _raise_if_out_of_bounds_vocab(
|
||||
self,
|
||||
vocab_size: int,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
) -> None:
|
||||
assert torch.all(bonus_token_ids < vocab_size)
|
||||
assert torch.all(bonus_token_ids >= 0)
|
||||
assert torch.all(draft_token_ids < vocab_size)
|
||||
assert torch.all(draft_token_ids >= 0)
|
||||
|
||||
|
||||
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
|
||||
"""Base class for samplers used for Speculative Decoding verification
|
||||
step which are deterministic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
|
||||
"""Base class for samplers used for Speculative Decoding verification
|
||||
step which are stochastic
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
@ -1,166 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeDeterministicBaseSampler)
|
||||
|
||||
|
||||
class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
|
||||
"""Apply typical acceptance sampling as described in section 3.3.1 in
|
||||
"MEDUSA: Simple LLM Inference Acceleration Framework with
|
||||
Multiple Decoding Heads"
|
||||
https://arxiv.org/pdf/2401.10774
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
posterior_threshold: float,
|
||||
posterior_alpha: float,
|
||||
strict_mode: bool = False,
|
||||
):
|
||||
"""Create a Typical Acceptance Sampler.
|
||||
|
||||
Args:
|
||||
strict_mode: Whether or not to perform shape/device/dtype checks
|
||||
during sampling. This catches correctness issues but adds
|
||||
nontrivial latency.
|
||||
posterior_threshold : A threshold value that sets a lower bound
|
||||
on the posterior probability of a token in target model for it
|
||||
to be accepted.
|
||||
posterior_alpha : A scaling factor for the entropy-based
|
||||
threshold in typical acceptance sampling.
|
||||
"""
|
||||
self._posterior_threshold = posterior_threshold
|
||||
self._posterior_alpha = posterior_alpha
|
||||
super().__init__(strict_mode=strict_mode)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Sample token ids using typical acceptance sampling. This accepts
|
||||
or rejects tokens proposed by the draft model using the probability
|
||||
of each token according to the draft and target models.
|
||||
|
||||
In the worst case where all draft tokens are rejected, it is guaranteed
|
||||
one token will be emitted.
|
||||
|
||||
In the case where all draft tokens are accepted, the bonus token will be
|
||||
accepted.
|
||||
|
||||
Args:
|
||||
target_probs: The probability distribution over token ids given
|
||||
context according to the target model.
|
||||
shape = [batch_size, num_speculative_tokens, vocab_size]
|
||||
|
||||
bonus_token_ids: The "bonus" token ids that are accepted iff all
|
||||
speculative tokens in a sequence are accepted.
|
||||
shape = [batch_size, num_bonus_tokens]
|
||||
|
||||
draft_probs: This parameter is unused by the acceptance sampler.
|
||||
|
||||
draft_token_ids: The token ids that were sampled from the draft
|
||||
probabilities.
|
||||
shape = [batch_size, num_speculative_tokens]
|
||||
|
||||
Returns:
|
||||
output_token_ids: The token ids sampled via rejection sampling,
|
||||
or -1 if unable to sample a token because the previous token
|
||||
was rejected.
|
||||
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
|
||||
"""
|
||||
# Only perform shape/dtype/device checking in strict mode, as it adds
|
||||
# overhead.
|
||||
if self._strict_mode:
|
||||
self._raise_if_incorrect_input(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids)
|
||||
target_probs = target_with_bonus_probs[:, :-1]
|
||||
accepted = self._evaluate_accepted_tokens(target_probs,
|
||||
draft_token_ids)
|
||||
recovered_token_ids = self._get_recovered_token_ids(target_probs)
|
||||
output_token_ids = self._create_output(accepted, recovered_token_ids,
|
||||
draft_token_ids,
|
||||
bonus_token_ids)
|
||||
return output_token_ids
|
||||
|
||||
def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
|
||||
r"""
|
||||
Evaluates and returns a mask of accepted tokens based on the
|
||||
posterior probabilities.
|
||||
|
||||
Args:
|
||||
target_probs (torch.Tensor): A tensor of shape
|
||||
(batch_size, k, vocab_size) representing the probabilities of
|
||||
each token in the vocabulary for each position in the proposed
|
||||
sequence. This is the distribution generated by the target
|
||||
model.
|
||||
draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k)
|
||||
representing the proposed token ids.
|
||||
|
||||
A draft token_id x_{n+k} is accepted if it satisfies the
|
||||
following condition
|
||||
|
||||
$$
|
||||
p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
|
||||
\min \left( \epsilon, \delta * \exp \left(
|
||||
-H(p_{\text{original}}(
|
||||
\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
|
||||
$$
|
||||
|
||||
where $p_{\text{original}}$ corresponds to target_probs
|
||||
and $\epsilon$ and $\delta$ correspond to hyperparameters
|
||||
specified using self._posterior_threshold and self._posterior_alpha
|
||||
|
||||
This method computes the posterior probabilities for the given
|
||||
draft token ids based on the provided target probabilities. It
|
||||
calculates the entropy of the posterior distribution and determines
|
||||
a dynamic threshold for each token position using the provided
|
||||
posterior_threshold and posterior_alpha values. The method then
|
||||
returns a boolean mask indicating which tokens can be accepted.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A boolean tensor of shape (batch_size, k) where each
|
||||
element indicates whether the corresponding draft token has
|
||||
been accepted or rejected. True indicates acceptance and false
|
||||
indicates rejection.
|
||||
"""
|
||||
device = target_probs.device
|
||||
candidates_prob = torch.gather(
|
||||
target_probs, dim=-1,
|
||||
index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
|
||||
# A small constant added to prevent computing the logarithm of zero,
|
||||
# which can lead to undefined values.
|
||||
epsilon = 1e-5
|
||||
posterior_entropy = -torch.sum(
|
||||
target_probs * torch.log(target_probs + epsilon), dim=-1)
|
||||
threshold = torch.minimum(
|
||||
torch.ones_like(posterior_entropy, device=device) *
|
||||
self._posterior_threshold,
|
||||
torch.exp(-posterior_entropy) * self._posterior_alpha,
|
||||
)
|
||||
accepted_mask = candidates_prob > threshold
|
||||
return accepted_mask
|
||||
|
||||
def _get_recovered_token_ids(self, target_probs):
|
||||
"""
|
||||
The recovered token ids will fill the first unmatched token
|
||||
by the target token.
|
||||
|
||||
Args:
|
||||
target_probs (torch.Tensor): A tensor of shape
|
||||
(batch_size, k, vocab_size) containing the target probability
|
||||
distribution.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor of shape (batch_size, k) with the recovered
|
||||
token ids which are selected from target probs.
|
||||
"""
|
||||
max_indices = torch.argmax(target_probs, dim=-1)
|
||||
|
||||
return max_indices
|
||||
@ -1,261 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DummyInputLayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, weight=None, bias=None):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight) if weight is not None else None
|
||||
self.bias = nn.Parameter(bias) if bias is not None else None
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class DummyOutputNorm(nn.Module):
|
||||
|
||||
def forward(self, x, residual):
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x + residual, None
|
||||
|
||||
|
||||
class EAGLE(nn.Module):
|
||||
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
|
||||
Reference implementation: https://github.com/SafeAILab/EAGLE
|
||||
|
||||
Differences from reference implementation:
|
||||
1. In reference, LlamaDecoderLayer implementation doesn't have
|
||||
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
|
||||
Following this approach, our implementation also disables
|
||||
the input_layernorm for the first decoder layer.
|
||||
2. We allow any decoder layer to be used in EAGLE whereas in reference
|
||||
decoder layer is fixed to be LlamaDecoderLayer.
|
||||
3. We have an optional token_map which reduces draft vocab to most
|
||||
frequently used tokens to give some additional speed-up by reducing
|
||||
sampling overhead. This is disabled unless the checkpoint file has
|
||||
explicit token_map tensor and config has an optional attribute
|
||||
truncated_vocab_size < vocab_size. To use this technique, one has to find
|
||||
the top-k most frequent tokens in target dataset and add that as a tensor
|
||||
in the draft checkpoint (using key token_map). Also, the draft config
|
||||
needs to have truncated_vocab_size (=k) as an attribute.
|
||||
4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
|
||||
module with regards to the use of additional RMS norms. The original
|
||||
EAGLE architecture 1) skips the pre-attention norm in its first
|
||||
transformer block, and 2) skips the final output norm, both of which we
|
||||
found to be suboptimal. We also add the support for separate norms
|
||||
applying to both the token embedding and hidden states before projection
|
||||
as in DeepSeek MTP, which we found to improve performance as well.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.config = config
|
||||
|
||||
architectures = getattr(self.config.model, "architectures", [])
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
|
||||
|
||||
self.model = model_cls(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self.fc = nn.Linear(config.model.hidden_size * 2,
|
||||
config.model.hidden_size,
|
||||
bias=getattr(self.config, "eagle_fc_bias", False))
|
||||
|
||||
# Modify layer normalization and residual connections as suggested
|
||||
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
|
||||
# While weights and biases are generally not needed,
|
||||
# they are retained here to support certain unit tests
|
||||
# (e.g., spec_decode/e2e/test_eagle_correctness.py).
|
||||
if not hasattr(self.config.model,
|
||||
"skip_prenorm") or self.config.model.skip_prenorm:
|
||||
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
|
||||
weight=self.model.model.layers[0].input_layernorm.weight)
|
||||
|
||||
if not hasattr(
|
||||
self.config.model,
|
||||
"skip_output_norm") or self.config.model.skip_output_norm:
|
||||
self.model.model.norm = DummyOutputNorm()
|
||||
|
||||
self.add_para_norm = False
|
||||
if hasattr(self.config.model,
|
||||
"add_para_norm") and self.config.model.add_para_norm:
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.add_para_norm = True
|
||||
|
||||
self.orig_vocab_size = config.vocab_size
|
||||
self.truncated_vocab_size = config.truncated_vocab_size
|
||||
self.unpadded_vocab_size = self.truncated_vocab_size
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.truncated_vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
self.truncated_vocab_size,
|
||||
logit_scale)
|
||||
|
||||
# Token map is a idx to token mapping to reduce the vocab size for
|
||||
# the draft model. Using smaller vocab size for draft, containing
|
||||
# only most frequent tokens reduces the speculation overhead. This
|
||||
# doesn't affect the acceptance rate much and thus gives more speed
|
||||
# -up. By default, this is disabled and is only used if the EAGLE
|
||||
# checkpoint file has token_map tensor.
|
||||
self.token_map = None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
|
||||
# Handle both empty previous_hidden_states
|
||||
# and mismatched batch size
|
||||
batch_size = inputs_embeds.size(0)
|
||||
if previous_hidden_states.size(0) == 0 or \
|
||||
previous_hidden_states.size(0) != batch_size:
|
||||
hidden_dim = self.config.model.hidden_size
|
||||
device = inputs_embeds.device
|
||||
# Create zero tensor with matching batch size
|
||||
previous_hidden_states = \
|
||||
torch.zeros(batch_size, hidden_dim, device=device)
|
||||
|
||||
if self.add_para_norm:
|
||||
inputs_embeds = torch.cat([
|
||||
self.enorm(inputs_embeds),
|
||||
self.hnorm(previous_hidden_states)
|
||||
],
|
||||
dim=-1)
|
||||
else:
|
||||
inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
|
||||
dim=-1)
|
||||
|
||||
inputs_embeds = self.fc(inputs_embeds)
|
||||
|
||||
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
|
||||
|
||||
hidden_states = self.model.model(
|
||||
input_ids=None,
|
||||
inputs_embeds=inputs_embeds,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
if self.token_map is not None:
|
||||
_logits = logits
|
||||
logits = -torch.inf * torch.ones(
|
||||
size=(*_logits.shape[:-1], self.orig_vocab_size),
|
||||
device=_logits.device,
|
||||
dtype=_logits.dtype)
|
||||
|
||||
logits[..., self.token_map] = _logits
|
||||
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
# This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||
# due to missing lm_head weights and its config being that of a
|
||||
# Llama model. Here's a compatible version with the same weights:
|
||||
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
|
||||
# Also, here's an example script for converting trained EAGLE
|
||||
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
|
||||
model_weights = {}
|
||||
for name, loaded_weight in weights:
|
||||
if name == "token_map":
|
||||
if self.config.truncated_vocab_size < self.config.vocab_size:
|
||||
self.token_map = nn.Parameter(loaded_weight,
|
||||
requires_grad=False)
|
||||
elif name.startswith("fc.weight"):
|
||||
weight_loader = getattr(self.fc.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.fc.weight, loaded_weight)
|
||||
elif name.startswith("fc.bias"):
|
||||
if self.fc.bias is not None:
|
||||
weight_loader = getattr(self.fc.bias, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.fc.bias, loaded_weight)
|
||||
else:
|
||||
logger.warning_once("Found bias in the loaded weights but "
|
||||
"the model config doesn't have bias.")
|
||||
elif name.startswith("enorm.weight"):
|
||||
weight_loader = getattr(self.enorm.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.enorm.weight, loaded_weight)
|
||||
elif name.startswith("hnorm.weight"):
|
||||
weight_loader = getattr(self.hnorm.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.hnorm.weight, loaded_weight)
|
||||
elif name.startswith("model.lm_head.") or name.startswith(
|
||||
"model.model."):
|
||||
model_weights[name.split("model.", 1)[-1]] = loaded_weight
|
||||
elif name.startswith("lm_head.") or name.startswith("model."):
|
||||
model_weights[name] = loaded_weight
|
||||
else:
|
||||
model_weights[f"model.{name}"] = loaded_weight
|
||||
|
||||
if "lm_head.weight" in model_weights:
|
||||
lm_head_weight = model_weights.pop("lm_head.weight")
|
||||
|
||||
if self.token_map is not None and\
|
||||
lm_head_weight.shape[0] > self.token_map.shape[0]:
|
||||
|
||||
lm_head_weight = lm_head_weight[self.token_map]
|
||||
|
||||
else:
|
||||
# NOTE(Shangming): initialize the placeholder for lm_head weight.
|
||||
lm_head_weight = torch.zeros(
|
||||
self.lm_head.org_vocab_size,
|
||||
self.lm_head.embedding_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
weight_loader = getattr(self.lm_head.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.lm_head.weight, lm_head_weight)
|
||||
|
||||
self.model.load_weights(model_weights.items())
|
||||
@ -239,14 +239,15 @@ _MULTIMODAL_MODELS = {
|
||||
|
||||
_SPECULATIVE_DECODING_MODELS = {
|
||||
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
# Temporarily disabled.
|
||||
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
|
||||
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
}
|
||||
|
||||
_TRANSFORMERS_MODELS = {
|
||||
|
||||
@ -132,14 +132,10 @@ class CudaPlatformBase(Platform):
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not supported on vLLM V0.")
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
|
||||
@ -326,15 +326,10 @@ class RocmPlatform(Platform):
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
if envs.VLLM_USE_V1:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not yet supported on vLLM V1."
|
||||
)
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
"Speculative decoding is not supported on vLLM V0.")
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
|
||||
@ -112,13 +112,6 @@ class RequestMetrics:
|
||||
model_execute_time: The time spent in the model execute function. This
|
||||
will include model forward, block/sync across
|
||||
workers, cpu-gpu sync time and sampling time.
|
||||
spec_token_acceptance_counts: number of accepted speculative tokens at
|
||||
each position; the first token is from
|
||||
the target model and is always accepted;
|
||||
e.g., when it's [10, 8, 4, 2] for a req,
|
||||
it means there were 10 forward passes in
|
||||
total, and there were 8, 4, 2 accepted
|
||||
tokens at 1st, 2nd, 3rd speculation step.
|
||||
"""
|
||||
arrival_time: float
|
||||
last_token_time: float
|
||||
@ -129,7 +122,6 @@ class RequestMetrics:
|
||||
scheduler_time: Optional[float] = None
|
||||
model_forward_time: Optional[float] = None
|
||||
model_execute_time: Optional[float] = None
|
||||
spec_token_acceptance_counts: Optional[list[int]] = None
|
||||
|
||||
|
||||
class SequenceDataDelta(
|
||||
@ -748,9 +740,7 @@ class SequenceGroup:
|
||||
last_token_time=arrival_time,
|
||||
first_scheduled_time=None,
|
||||
first_token_time=None,
|
||||
time_in_queue=None,
|
||||
spec_token_acceptance_counts=[0] *
|
||||
draft_size)
|
||||
time_in_queue=None)
|
||||
self.last_token_latency = 0.0
|
||||
self.lora_request = lora_request
|
||||
self.prompt_logprobs: Optional[PromptLogprobs] = None
|
||||
@ -1390,8 +1380,6 @@ class ExecuteModelRequest(
|
||||
previous_hidden_states: Optional[HiddenStates] = None
|
||||
# The number of forward steps to run.
|
||||
num_steps: int = 1
|
||||
# The step index for spec model input.
|
||||
spec_step_idx: Optional[int] = None
|
||||
# Finished request ids since last step.
|
||||
finished_requests_ids: list[str] = msgspec.field(default_factory=list)
|
||||
# The last sampled token ids for multi step decoding.
|
||||
|
||||
@ -1,506 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from array import array
|
||||
from itertools import chain, count
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
ExecuteModelRequest, SequenceData,
|
||||
SequenceGroupMetadata, get_all_seq_ids)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
TokenId = int
|
||||
|
||||
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
|
||||
|
||||
|
||||
class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
"""Implements a speculative scorer that uses batch expansion to get
|
||||
probabilities of speculative tokens according to the scoring model.
|
||||
|
||||
Batch expansion converts a list of sequences and multiple query positions
|
||||
to a new batch of sequences, each with a single query position. This allows
|
||||
for MQA-like scoring in speculative decoding without requiring an MQA
|
||||
kernel.
|
||||
|
||||
It is strictly less efficient than MQA scoring.
|
||||
|
||||
It only supports scoring the top1 proposal tokens of the proposer, instead
|
||||
of topk/tree.
|
||||
"""
|
||||
|
||||
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
|
||||
def score_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> SpeculativeScores:
|
||||
"""Score the proposed tokens via the scorer model.
|
||||
|
||||
This converts each input sequence to a set of k+1 target sequences. The
|
||||
target sequences have the unique continuations to be scored and a
|
||||
unique sequence ID that is different from all input sequence ids.
|
||||
|
||||
If a speculative sequence length would exceed the max model length, then
|
||||
no speculation is produced for that sequence.
|
||||
|
||||
Args:
|
||||
execute_model_req: The execution request.
|
||||
proposals: The speculative proposals to score.
|
||||
Returns:
|
||||
SpeculativeScores: The scores of each speculative token, along with
|
||||
which sequences were ignored during scoring.
|
||||
"""
|
||||
|
||||
# TODO(cade) perform this on GPU to remove blocking call.
|
||||
proposal_lens_list = proposals.proposal_lens.tolist()
|
||||
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
|
||||
|
||||
# Filter the list to ignore invalid proposals.
|
||||
proposal_token_ids_list_without_skips = [
|
||||
proposals for proposals in proposal_token_ids_list
|
||||
if VLLM_INVALID_TOKEN_ID not in proposals
|
||||
]
|
||||
|
||||
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens) = self._expand_batch(
|
||||
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
|
||||
proposal_token_ids_list=proposal_token_ids_list_without_skips,
|
||||
proposal_lens_list=proposal_lens_list,
|
||||
)
|
||||
|
||||
target_sampler_output = self._scorer_worker.execute_model(
|
||||
execute_model_req=execute_model_req.clone(
|
||||
seq_group_metadata_list=target_seq_group_metadata_list))
|
||||
assert len(target_sampler_output) == 1, "expected single-step output"
|
||||
target_sampler_output = target_sampler_output[0]
|
||||
|
||||
if not non_spec_indices:
|
||||
# All sequence groups in batch have spec decoding enabled
|
||||
return self._contract_batch_all_spec(
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
)
|
||||
else:
|
||||
# Batch has a mix of spec decode enabled and disabled seq groups
|
||||
return self._contract_batch(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
num_scoring_tokens=num_scoring_tokens,
|
||||
non_spec_indices=non_spec_indices,
|
||||
spec_indices=spec_indices,
|
||||
k=execute_model_req.num_lookahead_slots,
|
||||
)
|
||||
|
||||
def _expand_batch(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_token_ids_list: List[List[TokenId]],
|
||||
proposal_lens_list: List[int],
|
||||
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
|
||||
"""Given the input sequences and potentially multiple corresponding
|
||||
proposal tokens, create a new batch where each sequence has a single
|
||||
query token.
|
||||
"""
|
||||
|
||||
# vLLM currently only supports proposal lens equal to zero or the batch
|
||||
# proposal len. This adds some complexity (splitting the batch into spec
|
||||
# and non spec sequences) and should be removed in the future. It can be
|
||||
# done by supporting per-sequence proposal lens.
|
||||
(spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
|
||||
split_batch_by_proposal_len(
|
||||
seq_group_metadata_list, proposal_lens_list)
|
||||
|
||||
spec_expanded_seqs = self._create_scoring_model_input(
|
||||
seq_group_metadata_list=spec_seqs,
|
||||
proposal_token_ids=proposal_token_ids_list,
|
||||
# NOTE: We determine the seq ids in the expanded batch using the
|
||||
# full seq_group_metadata_list, instead of only spec_seqs.
|
||||
target_seq_ids_iter=self._create_target_seq_id_iterator(
|
||||
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
|
||||
)
|
||||
|
||||
num_scoring_tokens = len(spec_expanded_seqs)
|
||||
# Batch speculative and non-speculative (e.g. chunked prefill) requests
|
||||
# but make sure order is prefill|decode due to backend requirement.
|
||||
target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs
|
||||
|
||||
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens)
|
||||
|
||||
def _contract_non_speculative(
|
||||
self, scores: SpeculativeScores,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
|
||||
has_prompt_log: bool) -> SpeculativeScores:
|
||||
"""
|
||||
Augment input `scores` with non-speculative requests outputs.
|
||||
This includes decode requests with speculation turned off, as well
|
||||
as prefill requests when `enable_chunked_prefill` is set.
|
||||
For the latter, prefills are further separated into terminal and
|
||||
non-terminal chunks (from which no token is sampled).
|
||||
"""
|
||||
if not non_spec_indices:
|
||||
return scores
|
||||
|
||||
if has_prompt_log:
|
||||
# When prompt_logprobs is enabled, prefills yield output token
|
||||
# (and respective prob) in the last entry (prompt|out):
|
||||
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
|
||||
# With chunked prefill, non-terminal chunks have -1 on each
|
||||
# position: they're still picked, but they're discarded later.
|
||||
seq_meta = seq_group_metadata_list
|
||||
nospec_sizes = torch.tensor([
|
||||
seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
|
||||
for i in non_spec_indices
|
||||
])
|
||||
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
|
||||
else:
|
||||
# In this case only sampled tokens are returned, select all.
|
||||
nospec_sampled_token_idxs = list(
|
||||
range(len(non_spec_outputs.token_ids)))
|
||||
|
||||
scores.token_ids[non_spec_indices, :1] = \
|
||||
non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
scores.probs[non_spec_indices, :1, :] = \
|
||||
non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
scores.logprobs[non_spec_indices, :1, :] = \
|
||||
non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
if scores.hidden_states is not None:
|
||||
assert non_spec_outputs.hidden_states is not None
|
||||
scores.hidden_states[non_spec_indices, :1, :] = \
|
||||
non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
return scores
|
||||
|
||||
def _contract_batch(
|
||||
self,
|
||||
contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
target_sampler_output: SamplerOutput,
|
||||
proposals: SpeculativeProposals, num_scoring_tokens: int,
|
||||
non_spec_indices: List[int], spec_indices: List[int],
|
||||
k: int) -> SpeculativeScores:
|
||||
"""Contract the expanded batch back into its original size.
|
||||
This maps the scores of speculative tokens back to their original
|
||||
sequences.
|
||||
|
||||
contracted_bs is the original batch size, and the batch size that the
|
||||
target_sampler_output will be contracted to.
|
||||
"""
|
||||
contracted_bs = len(contracted_seq_group_metadata_list)
|
||||
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
|
||||
non_spec_target_token_ids, non_spec_target_probs,
|
||||
non_spec_target_logprobs,
|
||||
non_spec_target_hidden_states) = self._split_scoring_output(
|
||||
target_sampler_output, num_scoring_tokens)
|
||||
|
||||
# Map distinct sequences used to score each token
|
||||
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
|
||||
expanded_batch_size, k = proposals.proposal_token_ids.shape
|
||||
|
||||
# The number of tokens in the expanded batch used for speculation is
|
||||
# equal to the total expanded batch size minus the number of samples for
|
||||
# non-speculative sequences, prefill chunks with no out tokens included
|
||||
non_spec_expanded_bs = len(non_spec_indices)
|
||||
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
|
||||
|
||||
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
|
||||
target_probs = target_probs.reshape(*target_token_ids.shape,
|
||||
self._vocab_size)
|
||||
target_logprobs = target_logprobs.reshape(target_probs.shape)
|
||||
|
||||
if target_hidden_states is not None:
|
||||
target_hidden_states = target_hidden_states.reshape(
|
||||
*target_token_ids.shape, target_hidden_states.shape[-1])
|
||||
|
||||
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
|
||||
fill_value=-1)
|
||||
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
|
||||
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
|
||||
fill_value=-float("inf"))
|
||||
|
||||
if target_sampler_output.hidden_states is not None:
|
||||
all_hidden_states = target_hidden_states.new_zeros(
|
||||
size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
|
||||
else:
|
||||
all_hidden_states = None
|
||||
|
||||
has_prompt_log = any((sg.sampling_params.prompt_logprobs
|
||||
and sg.sampling_params.prompt_logprobs > 0)
|
||||
for sg in contracted_seq_group_metadata_list)
|
||||
# When prompt logprobs is enabled, lens of returned tensors go from
|
||||
# n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
|
||||
# We adjust stride accordingly to get the generated tokens and
|
||||
# their probs, but pass on prompt_logprobs as is.
|
||||
prompt_logprobs = None
|
||||
if (not self._scorer_worker.model_runner.disable_logprobs\
|
||||
and has_prompt_log):
|
||||
prompt_logprobs = [
|
||||
o.prompt_logprobs for o in target_sampler_output.outputs
|
||||
]
|
||||
elif not has_prompt_log:
|
||||
# When prompt logprobs are not to be returned,
|
||||
# we can ignore non-terminal chunks (no out token).
|
||||
non_spec_indices = [
|
||||
idx for idx in non_spec_indices
|
||||
if contracted_seq_group_metadata_list[idx].do_sample
|
||||
]
|
||||
|
||||
# "Contract" speculative.
|
||||
if spec_indices:
|
||||
all_tokens[spec_indices] = target_token_ids
|
||||
all_probs[spec_indices] = target_probs
|
||||
all_logprobs[spec_indices] = target_logprobs
|
||||
if all_hidden_states is not None:
|
||||
all_hidden_states[spec_indices] = target_hidden_states
|
||||
|
||||
spec_scores = SpeculativeScores(probs=all_probs,
|
||||
token_ids=all_tokens,
|
||||
logprobs=all_logprobs,
|
||||
hidden_states=all_hidden_states,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
|
||||
non_spec_outputs = SpeculativeScores(
|
||||
probs=non_spec_target_probs,
|
||||
token_ids=non_spec_target_token_ids,
|
||||
logprobs=non_spec_target_logprobs,
|
||||
hidden_states=non_spec_target_hidden_states)
|
||||
# Contract remaining nonspec entries based on non_spec_indices, if any.
|
||||
return self._contract_non_speculative(
|
||||
spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
|
||||
non_spec_outputs, has_prompt_log)
|
||||
|
||||
def _contract_batch_all_spec(
|
||||
self,
|
||||
target_sampler_output: SamplerOutput,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> SpeculativeScores:
|
||||
"""Contract the expanded batch back into its original size.
|
||||
This maps the scores of speculative tokens back to their original
|
||||
sequences.
|
||||
|
||||
It assumes all sequences in the batch were previously expanded.
|
||||
"""
|
||||
|
||||
# Map distinct sequences used to score each token
|
||||
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
|
||||
contracted_bs, k = proposals.proposal_token_ids.shape
|
||||
|
||||
# Reshape tensors to original batch size
|
||||
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
|
||||
contracted_bs, k + 1)
|
||||
target_probs = target_sampler_output.sampled_token_probs.reshape(
|
||||
*target_token_ids.shape, self._vocab_size)
|
||||
target_logprobs = target_sampler_output.logprobs.reshape(
|
||||
target_probs.shape)
|
||||
target_hidden_states = target_sampler_output.hidden_states
|
||||
if target_hidden_states is not None:
|
||||
target_hidden_states = target_hidden_states.reshape(
|
||||
*target_token_ids.shape, target_hidden_states.shape[-1])
|
||||
|
||||
return SpeculativeScores(probs=target_probs,
|
||||
token_ids=target_token_ids,
|
||||
logprobs=target_logprobs,
|
||||
hidden_states=target_hidden_states,
|
||||
prompt_logprobs=None)
|
||||
|
||||
def _create_scoring_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Given the original input sequences and proposed tokens from the draft
|
||||
model, create a list of target sequences that can be used for scoring.
|
||||
|
||||
target_seq_ids_iter provides sequence ids for the expanded batch,
|
||||
fulfilling the requirement that no seq id in the expanded batch is equal
|
||||
to the seq id in the original batch.
|
||||
"""
|
||||
|
||||
if not seq_group_metadata_list:
|
||||
return []
|
||||
|
||||
target_seq_group_metadata = list(
|
||||
chain.from_iterable(
|
||||
self._create_target_seq_group_metadata(
|
||||
seq_group_metadata,
|
||||
proposal_token_ids,
|
||||
i,
|
||||
target_seq_ids_iter,
|
||||
) for i, seq_group_metadata in enumerate(
|
||||
seq_group_metadata_list)))
|
||||
|
||||
return target_seq_group_metadata
|
||||
|
||||
def _create_target_seq_group_metadata(
|
||||
self,
|
||||
input_seq_group_metadata: SequenceGroupMetadata,
|
||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||
batch_index: int,
|
||||
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Given an input sequence group metadata and a list of draft tokens,
|
||||
create a list of target SequenceGroupMetadata, one for each
|
||||
token id that needs to be scored.
|
||||
|
||||
Naive speculative decoding requires K target model scores, one for each
|
||||
draft model token. However one can add a bonus token such that if each
|
||||
token is accepted, then a final token may be sampled from the model.
|
||||
This function creates K+1 target SequenceGroupMetadata to take
|
||||
advantage of the bonus token.
|
||||
"""
|
||||
assert len(input_seq_group_metadata.seq_data) == 1, (
|
||||
"Beam search "
|
||||
"not supported in speculative decoding")
|
||||
input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
|
||||
|
||||
token_ids_to_score = self._get_token_ids_to_score(
|
||||
proposal_token_ids[batch_index])
|
||||
|
||||
sampling_params = input_seq_group_metadata.sampling_params
|
||||
target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
for i, token_ids in enumerate(token_ids_to_score):
|
||||
target_seq_group_metadata_list.append(
|
||||
self._create_single_target_seq_group_metadata(
|
||||
input_seq_group_metadata,
|
||||
input_seq_id,
|
||||
next(target_seq_ids_iter),
|
||||
token_ids,
|
||||
sampling_params=sampling_params,
|
||||
))
|
||||
|
||||
return target_seq_group_metadata_list
|
||||
|
||||
@staticmethod
|
||||
def _create_single_target_seq_group_metadata(
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
seq_id: SeqId,
|
||||
target_seq_id: TargetSeqId,
|
||||
token_ids: List[TokenId],
|
||||
sampling_params: SamplingParams,
|
||||
) -> SequenceGroupMetadata:
|
||||
"""Create a single target SequenceGroupMetadata.
|
||||
|
||||
Args:
|
||||
seq_group_metadata: The metadata for the input sequence.
|
||||
seq_id: The input sequence ID.
|
||||
target_seq_id: The corresponding target sequence ID.
|
||||
token_ids: The list of token ids that are to be appended to the
|
||||
input sequence.
|
||||
"""
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_token_ids = seq_data.prompt_token_ids_array
|
||||
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
|
||||
mrope_position_delta = seq_data.mrope_position_delta
|
||||
|
||||
new_seq_data_dict = {
|
||||
target_seq_id:
|
||||
SequenceData(
|
||||
prompt_token_ids,
|
||||
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
new_output_token_ids),
|
||||
),
|
||||
}
|
||||
# This is a hack. Technically, spec decoding should compute
|
||||
# num_lookahead slots at one shot, but instead, it expands the batch
|
||||
# and evaluate one by one right now. context_len is seq_len - 1 because
|
||||
# the kv cache is filled by a previous batch in the batch expansion.
|
||||
for data in new_seq_data_dict.values():
|
||||
data.update_num_computed_tokens(data.get_len() - 1)
|
||||
data.mrope_position_delta = mrope_position_delta
|
||||
|
||||
return SequenceGroupMetadata(
|
||||
request_id=seq_group_metadata.request_id,
|
||||
is_prompt=seq_group_metadata.is_prompt,
|
||||
seq_data=new_seq_data_dict,
|
||||
sampling_params=sampling_params,
|
||||
block_tables={
|
||||
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||
},
|
||||
lora_request=None,
|
||||
token_chunk_size=1,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _split_scoring_output(
|
||||
sampler_output: SamplerOutput, num_scoring_tokens: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Split the target model output into speculative and non-speculative
|
||||
output.
|
||||
"""
|
||||
|
||||
# vLLM currently only supports proposal lens equal to zero or the batch
|
||||
# proposal len. This adds some complexity (splitting the batch into spec
|
||||
# and non spec sequences) and should be removed in the future. It can be
|
||||
# done by supporting per-sequence proposal lens.
|
||||
#
|
||||
# First samples are non-speculative, latter samples are from speculative
|
||||
# scoring (prefill|decode order).
|
||||
split_sizes = (sampler_output.sampled_token_ids.numel() -
|
||||
num_scoring_tokens, num_scoring_tokens)
|
||||
(non_spec_probs,
|
||||
spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
|
||||
(non_spec_sampled_tokens, spec_sampled_tokens
|
||||
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
|
||||
(non_spec_logprobs,
|
||||
spec_logprobs) = sampler_output.logprobs.split(split_sizes)
|
||||
|
||||
if sampler_output.hidden_states is not None:
|
||||
(non_spec_hidden_states, spec_hidden_states
|
||||
) = sampler_output.hidden_states.split(split_sizes)
|
||||
else:
|
||||
non_spec_hidden_states, spec_hidden_states = None, None
|
||||
|
||||
return (spec_sampled_tokens, spec_probs, spec_logprobs,
|
||||
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
|
||||
non_spec_logprobs, non_spec_hidden_states)
|
||||
|
||||
@staticmethod
|
||||
def _create_target_seq_id_iterator(
|
||||
seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
|
||||
"""Create an iterator for creating target sequence ids.
|
||||
Target sequence ids are distinct from sequence ids because we create a
|
||||
distinct target sequence id for each proposal token to be scored.
|
||||
|
||||
This implementation increments a counter starting at 1 + max of all
|
||||
provided input sequence ids.
|
||||
"""
|
||||
return count(start=max(seq_ids) + 1)
|
||||
|
||||
@staticmethod
|
||||
def _get_token_ids_to_score(
|
||||
full_spec_token_ids: List[TokenId] # shape: [k]
|
||||
) -> List[List[TokenId]]:
|
||||
"""Given an int tensor of proposal token ids, return a list of
|
||||
token ids that should be scored.
|
||||
|
||||
Returns k+1 output lists. The additional one is used for generating the
|
||||
bonus token.
|
||||
|
||||
Example:
|
||||
Input: [0, 1, 2, 3] (k=4)
|
||||
Output: (k+1 lists)
|
||||
[]
|
||||
[0]
|
||||
[0, 1]
|
||||
[0, 1, 2]
|
||||
[0, 1, 2, 3]
|
||||
"""
|
||||
empty_token_ids: List[TokenId] = []
|
||||
|
||||
token_ids_to_score = [empty_token_ids]
|
||||
token_ids_to_score.extend(full_spec_token_ids[:i + 1]
|
||||
for i in range(len(full_spec_token_ids)))
|
||||
return token_ids_to_score
|
||||
@ -1,349 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
|
||||
try:
|
||||
try:
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# vllm_flash_attn is not installed, try the ROCm FA metadata
|
||||
from vllm.attention.backends.rocm_flash_attn import (
|
||||
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
|
||||
except (ModuleNotFoundError, ImportError) as err:
|
||||
raise RuntimeError(
|
||||
"Draft model speculative decoding currently only supports "
|
||||
"CUDA and ROCm flash attention backend.") from err
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.worker.model_runner_base import (ModelRunnerBase,
|
||||
ModelRunnerInputBase,
|
||||
ModelRunnerWrapperBase)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# A flag to enable debug prints for the updated input tensors
|
||||
# before each step.
|
||||
debug_advance_input = False
|
||||
# A flag to allow GPU advance step for draft model runner.
|
||||
# Set to False for debugging.
|
||||
allow_gpu_advance_step = True
|
||||
|
||||
|
||||
class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
"""Specialized model runner for speculative decoding draft model.
|
||||
Since the draft model always execute k forward passes consecutively to
|
||||
generate k speculative tokens in a single speculative decoding step,
|
||||
we could get rid of most CPU-GPU synchronization and data transfer
|
||||
overheads by keeping model input and output tensors on GPU all the time.
|
||||
|
||||
TODOs:
|
||||
1. Currently supports only flash-attn, add support for other attn_backends.
|
||||
2. Support TP > 1 (this requires some designs because we do not expect
|
||||
any broadcasting inside execute_model).
|
||||
"""
|
||||
|
||||
def __init__(self, model_runner: ModelRunnerBase):
|
||||
super().__init__(model_runner)
|
||||
|
||||
self.indices_of_seq_with_bonus_tokens = None
|
||||
|
||||
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
|
||||
num_queries):
|
||||
|
||||
assert sampling_metadata.num_prompts == 0
|
||||
assert len(sampling_metadata.seq_groups) == num_queries
|
||||
assert sampling_metadata.selected_token_indices.shape == (
|
||||
num_queries, )
|
||||
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
|
||||
|
||||
# Verify that all sequences are decodes
|
||||
for i in range(num_queries):
|
||||
seq_group = sampling_metadata.seq_groups[i]
|
||||
|
||||
assert seq_group.is_prompt is False # No prompt
|
||||
assert seq_group.prompt_logprob_indices == [] # No prompt
|
||||
assert seq_group.sample_indices == [i] # Simple
|
||||
|
||||
def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
|
||||
last_output: SamplerOutput) -> ModelRunnerInputBase:
|
||||
# Currently, we expect "decode mode" only
|
||||
assert not model_input.is_prompt
|
||||
|
||||
# Get num_seqs
|
||||
num_seqs = len(model_input.seq_lens)
|
||||
num_queries = len(model_input.query_lens)
|
||||
|
||||
# Get output tokens GPU tensor
|
||||
sampled_token_ids = last_output.sampled_token_ids
|
||||
assert sampled_token_ids is not None
|
||||
|
||||
# Update attn_metadata
|
||||
attn_metadata = model_input.attn_metadata
|
||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
||||
|
||||
attn_metadata.advance_step(model_input, sampled_token_ids,
|
||||
self.block_size, num_seqs, num_queries)
|
||||
|
||||
# Update sampling_metadata
|
||||
sampling_metadata = model_input.sampling_metadata
|
||||
self._update_sampling_metadata(sampling_metadata, num_seqs,
|
||||
num_queries)
|
||||
|
||||
# Create new input
|
||||
new_model_input = self._model_input_cls(
|
||||
input_tokens=model_input.input_tokens,
|
||||
input_positions=model_input.input_positions,
|
||||
attn_metadata=attn_metadata,
|
||||
seq_lens=attn_metadata.seq_lens,
|
||||
query_lens=model_input.query_lens,
|
||||
lora_mapping=model_input.lora_mapping,
|
||||
lora_requests=model_input.lora_requests,
|
||||
multi_modal_kwargs=model_input.multi_modal_kwargs,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
is_prompt=False,
|
||||
)
|
||||
|
||||
# Ensure we skip CPU samples
|
||||
assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
|
||||
# We can reuse sampling tensors since every decode iteration is the same
|
||||
new_model_input.sampling_metadata.reuse_sampling_tensors = True
|
||||
|
||||
if debug_advance_input:
|
||||
logger.debug("NEW INPUT: ")
|
||||
logger.debug(" input_tokens = %s", new_model_input.input_tokens)
|
||||
logger.debug(" input_positions = %s",
|
||||
new_model_input.input_positions)
|
||||
logger.debug(" seq_lens = %d", new_model_input.seq_lens)
|
||||
logger.debug(" query_lens = %d", new_model_input.query_lens)
|
||||
logger.debug(" attn_metadata:")
|
||||
logger.debug(" seq_lens_tensor: %s",
|
||||
attn_metadata.seq_lens_tensor)
|
||||
logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
|
||||
logger.debug(" block_tables: %s", attn_metadata.block_tables)
|
||||
|
||||
return new_model_input
|
||||
|
||||
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
|
||||
"""Determines if draft_model_runner GPU multi-step can be used.
|
||||
Currently required conditions are:
|
||||
1. Only decodes
|
||||
2. Only flash-attn
|
||||
3. No LORA
|
||||
4. No prompt_adapter_config
|
||||
"""
|
||||
if not allow_gpu_advance_step:
|
||||
return False
|
||||
|
||||
# We allow multi-step GPU only in decode mode
|
||||
for seq_group in execute_model_req.seq_group_metadata_list:
|
||||
if seq_group.is_prompt:
|
||||
return False
|
||||
|
||||
# TODO: Add support for other attn backends
|
||||
if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
|
||||
return False
|
||||
|
||||
# TODO: Add support for LORA
|
||||
if self.lora_config:
|
||||
return False
|
||||
|
||||
# TODO: Add soft-tuning prompt adapter support
|
||||
return not self.prompt_adapter_config
|
||||
|
||||
def set_indices_of_seq_with_bonus_tokens(self,
|
||||
indices_of_seq_with_bonus_tokens):
|
||||
self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelRunnerInputBase,
|
||||
kv_caches: List[torch.Tensor],
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
**kwargs,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Executes num_steps forward passes with advacement of input tensors
|
||||
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
|
||||
|
||||
Optimizations used:
|
||||
1. Input tensors are updated on the GPU directly
|
||||
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
|
||||
them since we do batch expansion later that uses GPU outputs)
|
||||
3. Reuses sampling tensors (since we run only decodes and they have
|
||||
a repeating sampling logic)
|
||||
"""
|
||||
|
||||
# When num_steps == 1, we execute the fallback here for the GPU
|
||||
# advance_step, which runs prepare_inputs on CPU and for each spec
|
||||
# iteration invokes this function only once
|
||||
# (Look at multi-step-worker code)
|
||||
is_fallback = num_steps == 1
|
||||
if not is_fallback:
|
||||
# Since we do not broadcast data inside execute_model anymore,
|
||||
# we need to figure out the best way to support TP > 1 in this
|
||||
# case, because we will at least need to broadcast the sampled
|
||||
# tokens to all workers.
|
||||
if not self.is_driver_worker:
|
||||
raise ValueError("TP1DraftModelRunner only supports TP=1.")
|
||||
|
||||
# Sanity
|
||||
if self.lora_config is not None:
|
||||
raise ValueError("TP1DraftModelRunner has no support for LORA")
|
||||
if self.prompt_adapter_config is not None:
|
||||
raise ValueError("TP1DraftModelRunner has no support for "
|
||||
"prompt_adapter_config")
|
||||
if model_input.inputs_embeds is not None:
|
||||
raise ValueError("TP1DraftModelRunner has no support for "
|
||||
"inputs_embeds")
|
||||
if model_input.multi_modal_kwargs:
|
||||
raise ValueError(
|
||||
"TP1DraftModelRunner has no support for multi_modal_kwargs"
|
||||
)
|
||||
else:
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
assert model_input.prompt_adapter_requests is not None
|
||||
assert model_input.prompt_adapter_mapping is not None
|
||||
self.set_active_prompt_adapters(
|
||||
model_input.prompt_adapter_requests,
|
||||
model_input.prompt_adapter_mapping)
|
||||
|
||||
self.attn_state.begin_forward(model_input)
|
||||
|
||||
# Detect exec mode
|
||||
assert model_input.attn_metadata is not None
|
||||
use_cuda_graph = False
|
||||
if model_input.attn_metadata.num_prefills > 0:
|
||||
# In this case, execute_model(..) was called directly
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"execute_model(..) of draft_model_runner can be called "
|
||||
"directly only with a single-step prefill")
|
||||
else:
|
||||
# We can skip CPU samples for spec token generation.
|
||||
# (We do allow CPU samples for num_steps == 1 to support the
|
||||
# fallback case, where supports_gpu_multi_step(..) does not pass)
|
||||
model_input.sampling_metadata.skip_sampler_cpu_output = (
|
||||
not is_fallback)
|
||||
|
||||
# Attn attr defines if we use cuda graphs
|
||||
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
|
||||
|
||||
# Get model
|
||||
if use_cuda_graph:
|
||||
if model_input.inputs_embeds is None:
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, False)])
|
||||
else:
|
||||
graph_batch_size = model_input.inputs_embeds.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, True)])
|
||||
|
||||
if previous_hidden_states is not None:
|
||||
hidden_states = torch.cat([
|
||||
previous_hidden_states,
|
||||
torch.empty([
|
||||
graph_batch_size - previous_hidden_states.shape[0],
|
||||
*previous_hidden_states.shape[1:]
|
||||
],
|
||||
dtype=previous_hidden_states.dtype,
|
||||
device=previous_hidden_states.device)
|
||||
])
|
||||
else:
|
||||
hidden_states = None
|
||||
else:
|
||||
model_executable = self.model
|
||||
hidden_states = previous_hidden_states
|
||||
|
||||
outputs: List[SamplerOutput] = []
|
||||
for step in range(num_steps):
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
|
||||
model_execute_kwargs = {"previous_hidden_states": hidden_states} \
|
||||
if previous_hidden_states is not None else {}
|
||||
|
||||
compute_logits_kwargs = {}
|
||||
# Run model
|
||||
if hasattr(self.model.config, "num_nextn_predict_layers"):
|
||||
# for DeepSeek MTP only to use the corresponding layer for
|
||||
# each step
|
||||
spec_step_idx = kwargs.get("spec_step_idx", step)
|
||||
model_execute_kwargs["spec_step_idx"] = spec_step_idx
|
||||
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
|
||||
with set_forward_context(model_input.attn_metadata,
|
||||
self.vllm_config):
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
inputs_embeds=None,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
device=self.device,
|
||||
),
|
||||
**model_execute_kwargs,
|
||||
)
|
||||
|
||||
# Compute the logits.
|
||||
logits = self.model.compute_logits(hidden_states,
|
||||
model_input.sampling_metadata,
|
||||
**compute_logits_kwargs)
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
# Sample the next token.
|
||||
output = self.model_runner.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
outputs.append(output)
|
||||
|
||||
if self.return_hidden_states and is_fallback:
|
||||
if use_cuda_graph:
|
||||
indices = model_input.sampling_metadata\
|
||||
.selected_token_indices
|
||||
output.hidden_states = hidden_states[:len(indices)]
|
||||
else:
|
||||
output.hidden_states = hidden_states
|
||||
|
||||
if model_input.attn_metadata.num_prefills == 0 \
|
||||
and self.indices_of_seq_with_bonus_tokens is not None:
|
||||
assert output.sampled_token_ids is not None
|
||||
# output.sampled_token_ids should be of shape (num_seqs, 1)
|
||||
nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
|
||||
assert num_tokens_per_seq == 1
|
||||
count = 0
|
||||
for i in range(nums_seqs):
|
||||
bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
|
||||
count]
|
||||
if i != bonus_seq_idx:
|
||||
# The following might cause a cpu->gpu sync
|
||||
# However, the performance impact is negligible as we
|
||||
# benchmarked on H100.
|
||||
output.sampled_token_ids[
|
||||
i, :] = model_input.input_tokens[bonus_seq_idx]
|
||||
else:
|
||||
count += 1
|
||||
|
||||
# Prepare inputs for the next step
|
||||
if step != num_steps - 1:
|
||||
model_input = self._gpu_advance_step(model_input, outputs[-1])
|
||||
|
||||
return outputs
|
||||
@ -1,99 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest, PromptLogprobs
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeculativeProposals:
|
||||
"""Datastructure used to represent proposal tokens from some proposer. It
|
||||
also tracks how many speculative tokens each sequence has.
|
||||
"""
|
||||
|
||||
# Speculative proposal tokens.
|
||||
proposal_token_ids: torch.Tensor
|
||||
|
||||
# Probabilities of the proposal tokens according to the proposer.
|
||||
proposal_probs: torch.Tensor
|
||||
|
||||
# The valid length of each proposal; can be zero.
|
||||
proposal_lens: torch.Tensor
|
||||
|
||||
# A flag to mark that there's no available proposals
|
||||
no_proposals: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
return (f"SpeculativeProposals("
|
||||
f"proposal_token_ids={self.proposal_token_ids}, "
|
||||
f"proposal_probs={self.proposal_probs.shape}, "
|
||||
f"proposal_lens={self.proposal_lens})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeculativeScores:
|
||||
"""Datastructure used to represent the scores of speculative tokens
|
||||
according to the scoring model.
|
||||
"""
|
||||
|
||||
# Probabilities of the speculative tokens according to the scoring model.
|
||||
probs: torch.Tensor
|
||||
|
||||
# Log-probabilities of the speculative tokens according to the scoring
|
||||
# model. These values can be used to generate Logprob objects that are
|
||||
# returned to the user.
|
||||
logprobs: torch.Tensor
|
||||
|
||||
# Token ids sampled from the scoring model. Used for speculative bonus
|
||||
# tokens and also non-speculative normal decoding.
|
||||
token_ids: torch.Tensor
|
||||
|
||||
# Optional last hidden states from the scoring model.
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
# Scoring model may also return logprobs for prompt tokens
|
||||
# for each request, when chunked prefill is enabled.
|
||||
prompt_logprobs: Optional[List[PromptLogprobs]] = None
|
||||
|
||||
def __repr__(self):
|
||||
return (f"SpeculativeScores("
|
||||
f"probs={self.probs.shape}, "
|
||||
f"token_ids={self.token_ids.shape})")
|
||||
|
||||
|
||||
class SpeculativeProposer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
# If set, this contains all sequence IDs that were assigned
|
||||
# bonus tokens in their last forward pass.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SpeculativeScorer(ABC):
|
||||
|
||||
def __init__(self, scorer_worker: WorkerBase,
|
||||
device: Union[torch.device, str], vocab_size: int):
|
||||
self._scorer_worker = scorer_worker
|
||||
if isinstance(device, torch.device):
|
||||
device = device.type
|
||||
self._device = device
|
||||
self._vocab_size = vocab_size
|
||||
|
||||
@abstractmethod
|
||||
def score_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> SpeculativeScores:
|
||||
raise NotImplementedError
|
||||
@ -1,138 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker_base import DelegateWorkerBase
|
||||
|
||||
|
||||
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
|
||||
"""Worker for Medusa.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
DelegateWorkerBase.__init__(self, *args, **kwargs)
|
||||
# Lazy initialization list.
|
||||
self._proposer: Top1Proposer
|
||||
|
||||
def init_device(self):
|
||||
self.worker.init_device()
|
||||
|
||||
self._proposer = Top1Proposer(
|
||||
weakref.proxy(self), # type: ignore[arg-type]
|
||||
self.device,
|
||||
self.vocab_size,
|
||||
max_proposal_len=self.max_model_len,
|
||||
)
|
||||
|
||||
def set_include_gpu_probs_tensor(self):
|
||||
pass
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self):
|
||||
pass
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# Unused parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass to generate sample_len future tokens.
|
||||
Returns the list of sampler output, one per layer, along with indicator
|
||||
of whether torch tensor in sampler output need to be transposed in
|
||||
latter sampler_output_to_torch logic.
|
||||
|
||||
For medusa worker, this indicator shall be False.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
|
||||
seq_lens, query_lens = self._prepare_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
|
||||
generators = self.model_runner.get_generators(
|
||||
execute_model_req.finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list, seq_lens, query_lens, self.device,
|
||||
self.model_runner.pin_memory, generators)
|
||||
|
||||
model_outputs = self.model_runner.model.generate_proposals(
|
||||
previous_hidden_states=execute_model_req.previous_hidden_states.
|
||||
hidden_states,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
return model_outputs, False
|
||||
|
||||
def _prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if not seq_group_metadata_list:
|
||||
return [], []
|
||||
|
||||
seq_lens: List[int] = []
|
||||
query_lens: List[int] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
is_prompt = seq_group_metadata.is_prompt
|
||||
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seq_data_len = seq_data.get_len()
|
||||
if is_prompt:
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
seq_len = min(
|
||||
seq_data_len,
|
||||
context_len + seq_group_metadata.token_chunk_size)
|
||||
seq_lens.append(seq_len)
|
||||
query_lens.append(seq_len - context_len)
|
||||
else:
|
||||
seq_lens.append(seq_data_len)
|
||||
query_lens.append(1)
|
||||
|
||||
return seq_lens, query_lens
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
|
||||
return self._proposer.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def _raise_if_unsupported(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> None:
|
||||
"""MedusaWorker does not yet implement support for cache swap
|
||||
operations or beam search.
|
||||
"""
|
||||
if any([
|
||||
execute_model_req.blocks_to_swap_in,
|
||||
execute_model_req.blocks_to_swap_out,
|
||||
execute_model_req.blocks_to_copy
|
||||
]):
|
||||
raise NotImplementedError(
|
||||
"MedusaWorker does not support cache operations")
|
||||
|
||||
if any(
|
||||
len(seq_group_metadata.seq_data.keys()) != 1
|
||||
for seq_group_metadata in
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
raise NotImplementedError(
|
||||
"MedusaWorker does not support beam search.")
|
||||
@ -1,213 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
class SpecDecodeWorkerMetrics(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""Dataclass holding metrics emitted from the spec decode worker.
|
||||
"""
|
||||
|
||||
# The empirical acceptance rate of the proposal method on a per-token basis.
|
||||
# This is useful for evaluating how well the proposal method aligns with the
|
||||
# scoring method.
|
||||
draft_acceptance_rate: float
|
||||
|
||||
# The empirical efficiency, measured as the number of tokens emitted by the
|
||||
# system divided by the number of tokens that could be emitted by the system
|
||||
# if the proposal method were perfect.
|
||||
system_efficiency: float
|
||||
|
||||
# The number of speculative tokens produced by the proposal method.
|
||||
draft_tokens: int
|
||||
|
||||
# The number of tokens emitted by the entire system.
|
||||
emitted_tokens: int
|
||||
|
||||
# The number of tokens accepted by the scoring model and verification
|
||||
# routine, e.g. Llama2-70B and lossless rejection sampling.
|
||||
#
|
||||
# NOTE: Any token accepted by the verification routine is considered
|
||||
# accepted (regardless of if the speculative prefix is also accepted). The
|
||||
# user will usually see less accepted tokens. This metric is helpful when
|
||||
# evaluating alignment of the proposal method with the scoring model.
|
||||
accepted_tokens: int
|
||||
|
||||
# The number of speculative tokens per sequence.
|
||||
num_spec_tokens: int
|
||||
|
||||
|
||||
Timer = Callable[[], float]
|
||||
|
||||
|
||||
class AsyncMetricsCollector:
|
||||
"""Class which copies rejection/typical-acceptance sampler metrics
|
||||
from the device to CPU on a non-default Torch stream.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||
timer: Optional[Timer] = None,
|
||||
collect_interval_s: float = 5.0):
|
||||
self.spec_decode_sampler = spec_decode_sampler
|
||||
self._timer = time.time if timer is None else timer
|
||||
|
||||
self._rank: Optional[int] = None
|
||||
|
||||
# We don't have a device set yet.
|
||||
self._copy_stream: Optional[torch.cuda.Stream] = None
|
||||
|
||||
self._in_flight_copy: Optional[torch.cuda.Event] = None
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
self._aggregate_num_accepted_tokens = torch.tensor(
|
||||
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
|
||||
self._aggregate_num_emitted_tokens = torch.tensor(
|
||||
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
|
||||
self._aggregate_num_draft_tokens = 0
|
||||
|
||||
self._rejsample_metrics_collect_interval_s = collect_interval_s
|
||||
self._last_metrics_collect_time = self._timer()
|
||||
|
||||
def init_gpu_tensors(self, rank: int) -> None:
|
||||
self._rank = rank
|
||||
self._copy_stream = torch.cuda.Stream()
|
||||
|
||||
def init_tensors(self,
|
||||
rank: int,
|
||||
device_type: Union[torch.device, str] = 'cuda') -> None:
|
||||
self._rank = rank
|
||||
if isinstance(device_type, torch.device):
|
||||
device_type = device_type.type
|
||||
stream = current_platform.Stream
|
||||
if stream is not None:
|
||||
self._copy_stream = stream()
|
||||
|
||||
def maybe_collect_rejsample_metrics(
|
||||
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
|
||||
# Skip for any platform that doesn't have device Event
|
||||
if current_platform.Event is None:
|
||||
return None
|
||||
|
||||
# If a copy was initiated in the previous call, collect and return.
|
||||
if self._in_flight_copy is not None:
|
||||
ready_event = self._in_flight_copy
|
||||
self._in_flight_copy = None
|
||||
return self._collect_rejsample_metrics(k, ready_event)
|
||||
|
||||
# Otherwise, check if we should start a new copy.
|
||||
if self._should_collect_rejsample_metrics(self._timer()):
|
||||
assert self._in_flight_copy is None
|
||||
self._in_flight_copy = self._copy_rejsample_metrics_async()
|
||||
|
||||
return None
|
||||
|
||||
def _should_collect_rejsample_metrics(self, now: float) -> bool:
|
||||
"""Return whether or not this iteration should print sampling
|
||||
metrics.
|
||||
"""
|
||||
if self._rank != 0:
|
||||
return False
|
||||
|
||||
return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501
|
||||
|
||||
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
|
||||
"""Copy rejection/typical-acceptance sampling metrics
|
||||
(number of accepted tokens, etc) to CPU asynchronously.
|
||||
|
||||
Returns a device event recording when the copy is complete.
|
||||
"""
|
||||
assert self._copy_stream is not None
|
||||
self._copy_stream.wait_stream(current_platform.current_stream())
|
||||
|
||||
with current_platform.stream(self._copy_stream):
|
||||
self._aggregate_num_accepted_tokens.copy_(
|
||||
self.spec_decode_sampler.num_accepted_tokens,
|
||||
non_blocking=True)
|
||||
self._aggregate_num_emitted_tokens.copy_(
|
||||
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
|
||||
# Number of draft tokens is calculated on CPU, so no copy is
|
||||
# required.
|
||||
self._aggregate_num_draft_tokens = (
|
||||
self.spec_decode_sampler.num_draft_tokens)
|
||||
|
||||
aggregate_metrics_ready = current_platform.Event()
|
||||
aggregate_metrics_ready.record(self._copy_stream)
|
||||
|
||||
return aggregate_metrics_ready
|
||||
|
||||
def _collect_rejsample_metrics(
|
||||
self, k: int,
|
||||
ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
|
||||
"""Create metrics object from statistics copied asynchronously.
|
||||
|
||||
Args:
|
||||
k: int. The number of speculative tokens; used to determine system
|
||||
efficiency.
|
||||
ready_event: torch.cuda.Event. The CUDA event recording when the
|
||||
async GPU->CPU copy is complete.
|
||||
"""
|
||||
|
||||
ready_event.synchronize()
|
||||
|
||||
# update time of last collection
|
||||
self._last_metrics_collect_time = self._timer()
|
||||
|
||||
accepted_tokens = self._aggregate_num_accepted_tokens.item()
|
||||
emitted_tokens = self._aggregate_num_emitted_tokens.item()
|
||||
draft_tokens = self._aggregate_num_draft_tokens
|
||||
|
||||
max_num_emitted_tokens = self.get_max_num_emitted_tokens(
|
||||
draft_tokens, k)
|
||||
|
||||
if draft_tokens > 0:
|
||||
draft_acceptance_rate = accepted_tokens / draft_tokens
|
||||
else:
|
||||
draft_acceptance_rate = float("nan")
|
||||
|
||||
if max_num_emitted_tokens > 0:
|
||||
system_efficiency = emitted_tokens / max_num_emitted_tokens
|
||||
else:
|
||||
system_efficiency = float("nan")
|
||||
|
||||
return SpecDecodeWorkerMetrics(
|
||||
num_spec_tokens=k,
|
||||
draft_acceptance_rate=draft_acceptance_rate,
|
||||
system_efficiency=system_efficiency,
|
||||
accepted_tokens=accepted_tokens,
|
||||
draft_tokens=draft_tokens,
|
||||
emitted_tokens=emitted_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
|
||||
"""Calculate the number of emitted tokens, assuming all tokens are
|
||||
accepted.
|
||||
|
||||
This is equal to the number of sequences that have been speculated on,
|
||||
times (speculation len + 1). The +1 comes from the bonus token.
|
||||
"""
|
||||
# Determine the number of sequences that have been speculated on. Since
|
||||
# the batch size can be variable, we divide by k.
|
||||
assert draft_tokens % k == 0
|
||||
total_num_spec_seqs = draft_tokens // k
|
||||
|
||||
# A single sequence may emit k accepted tokens and one bonus token in
|
||||
# the best case.
|
||||
num_emitted_per_seq_if_all_accepted = k + 1
|
||||
|
||||
# The max num of emitted tokens is the number of speculated sequences
|
||||
# times the max emitted per seq.
|
||||
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
|
||||
@ -1,94 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
|
||||
|
||||
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
|
||||
"""Worker for MLPSpeculator models.
|
||||
|
||||
Not currently compatible with LoRA or chunked prefill.
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
|
||||
# therefore does not need this parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass to generate sample_len future tokens.
|
||||
Returns the list of sampler output, one per layer, along with indicator
|
||||
of whether torch tensor in sampler output need to be transposed in
|
||||
latter sampler_output_to_torch logic.
|
||||
|
||||
For mlp spec worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
|
||||
(input_tokens, seq_lens,
|
||||
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
|
||||
|
||||
generators = self.model_runner.get_generators(
|
||||
execute_model_req.finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list, seq_lens, query_lens, self.device,
|
||||
self.model_runner.pin_memory, generators)
|
||||
|
||||
model_outputs = self.model_runner.model.generate_proposals(
|
||||
input_ids=input_tokens,
|
||||
previous_hidden_states=execute_model_req.previous_hidden_states.
|
||||
hidden_states,
|
||||
num_predict_tokens=sample_len,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
assert len(model_outputs) == sample_len
|
||||
|
||||
return model_outputs, True
|
||||
|
||||
def _prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
) -> Tuple[torch.Tensor, List[int], List[int]]:
|
||||
if not seq_group_metadata_list:
|
||||
return torch.empty(0, device=self.device), [], []
|
||||
|
||||
input_tokens: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
query_lens: List[int] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
is_prompt = seq_group_metadata.is_prompt
|
||||
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seq_data_len = seq_data.get_len()
|
||||
if is_prompt:
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
seq_len = min(
|
||||
seq_data_len,
|
||||
context_len + seq_group_metadata.token_chunk_size)
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
seq_lens.append(seq_len)
|
||||
input_tokens.extend(tokens)
|
||||
query_lens.append(seq_len - context_len)
|
||||
else:
|
||||
seq_lens.append(seq_data_len)
|
||||
input_tokens.append(seq_data.get_last_token_id())
|
||||
query_lens.append(1)
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
return input_tokens_tensor, seq_lens, query_lens
|
||||
@ -1,160 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SequenceData,
|
||||
SequenceGroupMetadata, get_all_seq_ids)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
|
||||
|
||||
class MQAScorer(SpeculativeScorer):
|
||||
|
||||
def score_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> SpeculativeScores:
|
||||
target_seq_group_metadata_list = []
|
||||
target_seq_id_start = max(
|
||||
get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
|
||||
all_proposal_tokens = proposals.proposal_token_ids.tolist()
|
||||
all_proposal_lengths = proposals.proposal_lens.tolist()
|
||||
for i, seq_group_metadata in enumerate(
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
if all_proposal_lengths[i] == 0:
|
||||
# Keep prompt seqs untouched (keep computed_tokens for chunks).
|
||||
target_seq_group_metadata_list.append(seq_group_metadata)
|
||||
continue
|
||||
|
||||
seq_data_dict = seq_group_metadata.seq_data
|
||||
assert len(seq_data_dict) == 1
|
||||
seq_id = next(iter(seq_data_dict.keys()))
|
||||
|
||||
seq_data: SequenceData = seq_data_dict[seq_id]
|
||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||
output_token_ids = seq_data.get_output_token_ids()
|
||||
proposal_token_ids = all_proposal_tokens[
|
||||
i][:all_proposal_lengths[i]]
|
||||
new_output_token_ids = [*output_token_ids, *proposal_token_ids]
|
||||
|
||||
target_seq_id = target_seq_id_start + i
|
||||
new_seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
output_token_ids=new_output_token_ids,
|
||||
)
|
||||
new_seq_data.update_num_computed_tokens(
|
||||
len(prompt_token_ids) + len(output_token_ids) - 1)
|
||||
|
||||
# Ensure that the new decode sequence has at least one token.
|
||||
assert len(output_token_ids) >= 1
|
||||
new_seq_data_dict = {target_seq_id: new_seq_data}
|
||||
|
||||
new_seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=seq_group_metadata.request_id,
|
||||
is_prompt=seq_group_metadata.is_prompt,
|
||||
seq_data=new_seq_data_dict,
|
||||
sampling_params=seq_group_metadata.sampling_params,
|
||||
block_tables={
|
||||
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||
},
|
||||
lora_request=None,
|
||||
)
|
||||
target_seq_group_metadata_list.append(new_seq_group_metadata)
|
||||
|
||||
target_sampler_output = self._scorer_worker.execute_model(
|
||||
execute_model_req=execute_model_req.clone(
|
||||
seq_group_metadata_list=target_seq_group_metadata_list))
|
||||
|
||||
target_sampler_output = target_sampler_output[0]
|
||||
|
||||
k = execute_model_req.num_lookahead_slots
|
||||
bs = len(execute_model_req.seq_group_metadata_list)
|
||||
target_token_ids = target_sampler_output.sampled_token_ids
|
||||
target_probs = target_sampler_output.sampled_token_probs
|
||||
target_logprobs = target_sampler_output.logprobs
|
||||
prompt_logprobs = None
|
||||
|
||||
# If all requests have the same number of query tokens, we can avoid
|
||||
# the for loop to build output for better performance.
|
||||
if min(all_proposal_lengths) == k:
|
||||
# Regular decodes only.
|
||||
assert all(not sg.is_prompt
|
||||
for sg in target_seq_group_metadata_list
|
||||
if sg.is_prompt)
|
||||
bs, _ = proposals.proposal_token_ids.shape
|
||||
all_tokens = target_token_ids.reshape(bs, k + 1)
|
||||
all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
|
||||
all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
|
||||
else:
|
||||
# We either have decodes with different lens or prefill+decodes.
|
||||
all_tokens = target_token_ids.new_full(size=(bs, k + 1),
|
||||
fill_value=-1)
|
||||
all_probs = target_probs.new_zeros(*all_tokens.shape,
|
||||
self._vocab_size)
|
||||
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
|
||||
fill_value=-float("inf"))
|
||||
target_token_ids = target_token_ids.flatten()
|
||||
|
||||
# When prompt logprobs is enabled, lens of returned tensors go from
|
||||
# n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
|
||||
# We adjust stride accordingly to get the generated tokens and
|
||||
# their probs, but pass on prompt_logprobs as is, since it may be
|
||||
# that n_prompts >> K.
|
||||
has_prompt_log = any((sg.sampling_params.prompt_logprobs
|
||||
and sg.sampling_params.prompt_logprobs > 0)
|
||||
for sg in target_seq_group_metadata_list)
|
||||
# TODO (NickLucche) we should surface `disable_logprobs` as to not
|
||||
# break abstraction to get its value.
|
||||
if (not self._scorer_worker.model_runner.disable_logprobs\
|
||||
and has_prompt_log):
|
||||
prompt_logprobs = [
|
||||
o.prompt_logprobs for o in target_sampler_output.outputs
|
||||
]
|
||||
|
||||
# Split loop into prefill|decode for readability.
|
||||
start_loc, i = 0, 0
|
||||
while i < len(target_seq_group_metadata_list
|
||||
) and target_seq_group_metadata_list[i].is_prompt:
|
||||
seq_meta = target_seq_group_metadata_list[i]
|
||||
end_loc = start_loc
|
||||
if has_prompt_log:
|
||||
end_loc += seq_meta.token_chunk_size
|
||||
elif seq_meta.do_sample:
|
||||
end_loc += 1
|
||||
|
||||
# Skip chunks with no output tokens.
|
||||
if seq_meta.do_sample:
|
||||
# Get sampled token (last position in chunk) and its prob.
|
||||
all_tokens[i, 0] = target_token_ids[end_loc - 1]
|
||||
all_probs[i, 0] = target_probs[end_loc - 1]
|
||||
all_logprobs[i, 0] = target_logprobs[end_loc - 1]
|
||||
|
||||
i += 1
|
||||
start_loc = end_loc
|
||||
# Decodes.
|
||||
while i < len(target_seq_group_metadata_list):
|
||||
proposed_len, seq_meta = all_proposal_lengths[
|
||||
i], target_seq_group_metadata_list[i]
|
||||
output_len = proposed_len + 1
|
||||
end_loc = start_loc + output_len
|
||||
all_tokens[
|
||||
i, :output_len] = target_token_ids[start_loc:end_loc]
|
||||
all_probs[i, :output_len] = target_probs[start_loc:end_loc]
|
||||
all_logprobs[
|
||||
i, :output_len] = target_logprobs[start_loc:end_loc]
|
||||
start_loc = end_loc
|
||||
i += 1
|
||||
|
||||
hidden_states = None
|
||||
if target_sampler_output.hidden_states is not None:
|
||||
hidden_states = target_sampler_output.hidden_states.reshape(
|
||||
bs, (k + 1), -1)
|
||||
|
||||
return SpeculativeScores(probs=all_probs,
|
||||
token_ids=all_tokens,
|
||||
logprobs=all_logprobs,
|
||||
hidden_states=hidden_states,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
@ -1,423 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import weakref
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker_base import DelegateWorkerBase
|
||||
|
||||
|
||||
class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
|
||||
"""The MultiStepWorker is equivalent to a Worker except that it allows
|
||||
multiple forward passes in a single call, assuming the scheduler has
|
||||
allocated enough space to store the additional KV. This reduces overhead
|
||||
by invoking the scheduler less.
|
||||
|
||||
The MultiStepWorker does not support cache swap operations, or beam search.
|
||||
Cache swap operations do not require large modifications. On the other hand,
|
||||
beam search requires memory allocations during sequence forks and thus
|
||||
requires more thought for MultiStepWorker support.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
DelegateWorkerBase.__init__(self, *args, **kwargs)
|
||||
# Lazy initialization list.
|
||||
self._proposer: SpeculativeProposer
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.worker.init_device()
|
||||
self._proposer = Top1Proposer(
|
||||
weakref.proxy(self), # type: ignore[arg-type]
|
||||
self.device,
|
||||
self.vocab_size,
|
||||
max_proposal_len=self.max_model_len,
|
||||
)
|
||||
|
||||
def set_include_gpu_probs_tensor(self) -> None:
|
||||
# Need include_gpu_probs_tensor for MultiStepWorker
|
||||
self.model_runner.sampler.include_gpu_probs_tensor = True
|
||||
if hasattr(self.model_runner.model, "sampler"):
|
||||
(self.model_runner.model.sampler.include_gpu_probs_tensor) = True
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
|
||||
self.model_runner.sampler.should_modify_greedy_probs_inplace = True
|
||||
if hasattr(self.model_runner.model, "sampler"):
|
||||
(self.model_runner.model.sampler.should_modify_greedy_probs_inplace
|
||||
) = True
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass sample_len times. Returns the list of
|
||||
sampler output, one per model forward pass, along with indicator of
|
||||
whether torch tensor in sampler output need to be transposed in latter
|
||||
sampler_output_to_torch logic.
|
||||
|
||||
For multi step worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
# Expand the batch for sequences with a bonus token.
|
||||
# Perform a forward pass on the expanded batch and filter the
|
||||
# response to retain only the original sequences' responses.
|
||||
expanded_request, indices_of_seq_with_bonus_tokens =\
|
||||
self._expand_execute_model_request(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
# Run model sample_len times.
|
||||
model_outputs: List[SamplerOutput] = []
|
||||
if current_platform.is_cuda_alike() and isinstance(
|
||||
self.model_runner, TP1DraftModelRunner
|
||||
) and self.model_runner.supports_gpu_multi_step(expanded_request):
|
||||
# Here we run the draft_model_runner with multi-step prepare
|
||||
# on the GPU directly
|
||||
expanded_request.num_steps = sample_len
|
||||
self.model_runner.set_indices_of_seq_with_bonus_tokens(
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs = self.execute_model(
|
||||
execute_model_req=expanded_request)
|
||||
else:
|
||||
# Here we run multi-step directly, with every step prepared
|
||||
# on the CPU.
|
||||
# TODO: Remove this branch once DraftModelRunner supports TP>1
|
||||
# and other restrictions that are part of DraftModelRunner's
|
||||
# supports_gpu_multi_step(..)
|
||||
if expanded_request.previous_hidden_states is not None:
|
||||
self.worker.model_runner.return_hidden_states = True
|
||||
for _ in range(sample_len):
|
||||
model_output: List[SamplerOutput] = self.worker.execute_model(
|
||||
execute_model_req=expanded_request)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
self._maybe_update_previous_hidden_states(
|
||||
model_output, expanded_request)
|
||||
|
||||
self._append_new_tokens(
|
||||
model_output, expanded_request.seq_group_metadata_list,
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
# move indices to device to avoid stream sync
|
||||
indices_of_seq_with_bonus_tokens = torch.tensor(
|
||||
indices_of_seq_with_bonus_tokens, device=self.device)
|
||||
filtered_model_outputs = self._filter_model_output(
|
||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||
return filtered_model_outputs, True
|
||||
|
||||
@staticmethod
|
||||
def _maybe_update_previous_hidden_states(
|
||||
model_output: SamplerOutput,
|
||||
expanded_request: ExecuteModelRequest) -> None:
|
||||
"""
|
||||
Updates the previous hidden states in an expanded request
|
||||
in-place with the hidden states from the model output.
|
||||
"""
|
||||
if expanded_request.previous_hidden_states is not None:
|
||||
expanded_request.previous_hidden_states = HiddenStates(
|
||||
model_output.hidden_states,
|
||||
expanded_request.seq_group_metadata_list)
|
||||
|
||||
@staticmethod
|
||||
def _expand_execute_model_request(
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_with_bonus_token_in_last_step: set,
|
||||
) -> Tuple[ExecuteModelRequest, List[int]]:
|
||||
"""
|
||||
Expands the execute model request based on sequences with bonus
|
||||
tokens.
|
||||
|
||||
For each sequence with a bonus token, this method creates a new
|
||||
sequence without the bonus token and adds it to the execute model
|
||||
request. The original sequence groups are also retained. The indices
|
||||
of the original sequence groups are returned for further processing.
|
||||
|
||||
Args:
|
||||
execute_model_req (ExecuteModelRequest): The original execute
|
||||
model request.
|
||||
seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
|
||||
contain bonus tokens.
|
||||
|
||||
Returns:
|
||||
Tuple[ExecuteModelRequest, List[int]]: The updated execute model
|
||||
request with expanded sequences and a list of indices corresponding
|
||||
to the original sequence groups.
|
||||
"""
|
||||
updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
updated_execute_model_req = execute_model_req.clone(
|
||||
updated_seq_group_metadata_list)
|
||||
indices_of_original_sequence_groups = []
|
||||
for seq_group in execute_model_req.seq_group_metadata_list:
|
||||
seq_group_has_bonus_tokens = False
|
||||
for seq_id, _ in seq_group.seq_data.items():
|
||||
# Identify sequences with bonus tokens in the sequence group.
|
||||
if seq_id in seq_with_bonus_token_in_last_step:
|
||||
seq_group_has_bonus_tokens = True
|
||||
break
|
||||
if seq_group_has_bonus_tokens:
|
||||
#Create new sequences without the last bonus token. These new
|
||||
# sequence have the same sequence id as the original sequence.
|
||||
# We create a new sequence group and add them there.
|
||||
updated_seq_group_without_bonus_token = \
|
||||
MultiStepWorker._copy_seq_metadata_excluding_last_token(
|
||||
seq_group, seq_with_bonus_token_in_last_step)
|
||||
updated_seq_group_metadata_list.append(
|
||||
updated_seq_group_without_bonus_token)
|
||||
# Add the original sequence group.
|
||||
updated_seq_group_metadata_list.append(
|
||||
MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
|
||||
# Record the index of the original sequence group.
|
||||
indices_of_original_sequence_groups.append(
|
||||
len(updated_seq_group_metadata_list) - 1)
|
||||
|
||||
updated_execute_model_req.seq_group_metadata_list =\
|
||||
updated_seq_group_metadata_list
|
||||
|
||||
if isinstance(updated_execute_model_req.previous_hidden_states,
|
||||
HiddenStates):
|
||||
updated_execute_model_req.previous_hidden_states\
|
||||
.expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)
|
||||
|
||||
return updated_execute_model_req, indices_of_original_sequence_groups
|
||||
|
||||
@staticmethod
|
||||
def _filter_model_output(
|
||||
expanded_batch_outputs: List[SamplerOutput],
|
||||
output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
|
||||
"""
|
||||
Filters the model output to include only the specified sequence
|
||||
outputs. This method contracts the expanded batch output from the
|
||||
model to retain the outputs of only those sequences indicated by the
|
||||
provided indices.
|
||||
|
||||
Args:
|
||||
expanded_batch_output (List[SamplerOutput]): The expanded output
|
||||
batch from the model.
|
||||
output_indices_to_retain (torch.Tensor): Indices of the model
|
||||
outputs to retain.
|
||||
|
||||
Returns:
|
||||
List[SamplerOutput]: A list containing the filtered model
|
||||
outputs for the specified indices.
|
||||
"""
|
||||
return [
|
||||
SamplerOutput(
|
||||
outputs=[
|
||||
expanded_batch_output.outputs[i]
|
||||
for i in output_indices_to_retain
|
||||
] if len(expanded_batch_output.outputs) > 0 else [],
|
||||
sampled_token_probs=(
|
||||
expanded_batch_output.
|
||||
sampled_token_probs[output_indices_to_retain]
|
||||
if expanded_batch_output.sampled_token_probs is not None
|
||||
else None),
|
||||
logprobs=(
|
||||
expanded_batch_output.logprobs[output_indices_to_retain]
|
||||
if expanded_batch_output.logprobs is not None else None),
|
||||
sampled_token_ids=(expanded_batch_output.
|
||||
sampled_token_ids[output_indices_to_retain]
|
||||
if expanded_batch_output.sampled_token_ids
|
||||
is not None else None))
|
||||
for expanded_batch_output in expanded_batch_outputs
|
||||
]
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: set,
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
return self._proposer.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
@staticmethod
|
||||
def _append_new_tokens(
|
||||
model_output: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
indices_of_seq_with_bonus_tokens: List[int]) -> None:
|
||||
"""Given model output from a single run, append the tokens to the
|
||||
sequences. This is normally done outside of the worker, but it is
|
||||
required if the worker is to perform multiple forward passes.
|
||||
"""
|
||||
count = 0
|
||||
for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
|
||||
zip(seq_group_metadata_list, model_output)):
|
||||
seq_group_metadata.is_prompt = False
|
||||
|
||||
for seq_output in sequence_group_outputs.samples:
|
||||
# NOTE: Beam search is not supported, so we can assume that
|
||||
# parent_seq_id == seq_id.
|
||||
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
|
||||
|
||||
token_id = seq_output.output_token
|
||||
token_logprob = seq_output.logprobs[token_id]
|
||||
# Determine the actual token ID to be generated,
|
||||
# considering bonus tokens
|
||||
if index != indices_of_seq_with_bonus_tokens[count]:
|
||||
bonus_seq_metadata = seq_group_metadata_list[
|
||||
indices_of_seq_with_bonus_tokens[count]]
|
||||
_, bonus_token_seq_data = next(
|
||||
iter(bonus_seq_metadata.seq_data.items()))
|
||||
token_id = bonus_token_seq_data.output_token_ids[-1]
|
||||
else:
|
||||
count += 1
|
||||
|
||||
seq.append_token_id(token_id, token_logprob.logprob,
|
||||
seq_output.output_embed)
|
||||
seq.update_num_computed_tokens(1)
|
||||
|
||||
@staticmethod
|
||||
def _shallow_copy_seq_group_metadata(
|
||||
seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
|
||||
"""Copy input data structures to remove side-effects when input data
|
||||
structures are shared with other modules.
|
||||
|
||||
Helpful when the vLLM scheduler runs in the same process as the worker.
|
||||
The alternative is deep-copying (or other form of deep copy); this has
|
||||
performance downsides.
|
||||
"""
|
||||
# Shallow-copy the SequenceGroupMetadata. This allows us to
|
||||
# append tokens and change is_prompt without external side-effects.
|
||||
# We must shallow-copy seq_group_metadata as is_prompt could change.
|
||||
new_seq_group_metadata = copy.copy(seq_group_metadata)
|
||||
|
||||
# We must shallow-copy seq_data as we will append token ids
|
||||
new_seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
new_seq_data[seq_id].output_token_ids =\
|
||||
old_seq_data.output_token_ids[:]
|
||||
|
||||
new_seq_group_metadata.seq_data = new_seq_data
|
||||
return new_seq_group_metadata
|
||||
|
||||
@staticmethod
|
||||
def _copy_seq_metadata_excluding_last_token(
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
seq_ids_to_copy: Set[int],
|
||||
) -> SequenceGroupMetadata:
|
||||
"""
|
||||
Creates a shallow copy of the given SequenceGroupMetadata, retaining
|
||||
only the sequence IDs specified in seq_ids_to_copy. For each of these
|
||||
sequence IDs, all output_token_ids except the last one are copied.
|
||||
Sequence IDs not in seq_ids_to_copy are excluded from the copy.
|
||||
|
||||
Parameters:
|
||||
seq_group_metadata (SequenceGroupMetadata): The original sequence
|
||||
group metadata.
|
||||
seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
|
||||
copy.
|
||||
|
||||
Returns:
|
||||
SequenceGroupMetadata: A shallow copy of the sequence group metadata
|
||||
with the specified modifications.
|
||||
"""
|
||||
# Shallow-copy the SequenceGroupMetadata.
|
||||
new_seq_group_metadata = copy.copy(seq_group_metadata)
|
||||
# Shallow-copy seq_data and modify the output_token_ids.
|
||||
new_seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
if (seq_id in seq_ids_to_copy):
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
# Copy all the output token ids except the last.
|
||||
# Also reduce num_computed_tokens by 1 since we are not
|
||||
# including the last output token.
|
||||
# NOTE: num_computed_tokens is not directly used by the
|
||||
# speculative decoding workers, as it is only relevant for
|
||||
# chunked prefill, which is disabled for speculative decoding.
|
||||
# However, to maintain consistency in num_computed_tokens,
|
||||
# we update it here.
|
||||
new_seq_data[seq_id].output_token_ids =\
|
||||
old_seq_data.output_token_ids[:-1]
|
||||
new_seq_data[seq_id].update_num_computed_tokens(-1)
|
||||
new_seq_group_metadata.seq_data = new_seq_data
|
||||
return new_seq_group_metadata
|
||||
|
||||
def _assert_enough_kv_space(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
num_steps: int) -> None:
|
||||
"""Assert there are enough physical blocks per sequence to store the
|
||||
current KV plus additional KV from num_steps tokens.
|
||||
"""
|
||||
assert self.model_runner.block_size is not None
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
# Only one seq_id is guaranteed because there is no beam search.
|
||||
seq_id = list(seq_group_metadata.seq_data.keys())[0]
|
||||
seq = seq_group_metadata.seq_data[seq_id]
|
||||
|
||||
# After num_steps, the seq len will be the current seq len
|
||||
# plus one token per step.
|
||||
final_seq_len = seq.get_len() + num_steps
|
||||
|
||||
# We will have final_seq_len - 1 KV because vLLM saves KV for a
|
||||
# token in the iteration after the token was generated.
|
||||
required_num_kv_slots = final_seq_len - 1
|
||||
|
||||
# The allocated number of kv slots is the number of allocated blocks
|
||||
# times the number of slots of block.
|
||||
number_physical_blocks = len(
|
||||
seq_group_metadata.block_tables[seq_id])
|
||||
allocated_kv_slots = (number_physical_blocks *
|
||||
self.model_runner.block_size)
|
||||
|
||||
if required_num_kv_slots > allocated_kv_slots:
|
||||
request_id = seq_group_metadata.request_id
|
||||
raise ValueError(
|
||||
"The worker attempted to run "
|
||||
f"{num_steps} times but found insufficient KV space for "
|
||||
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
|
||||
f"{required_num_kv_slots=}).")
|
||||
|
||||
def _raise_if_unsupported(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> None:
|
||||
"""MultiStepWorker does not yet implement support for cache swap
|
||||
operations or beam search.
|
||||
"""
|
||||
if any([
|
||||
execute_model_req.blocks_to_swap_in,
|
||||
execute_model_req.blocks_to_swap_out,
|
||||
execute_model_req.blocks_to_copy
|
||||
]):
|
||||
raise NotImplementedError(
|
||||
"MultiStepWorker does not support cache operations")
|
||||
|
||||
if any(
|
||||
len(seq_group_metadata.seq_data.keys()) != 1
|
||||
for seq_group_metadata in
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
raise NotImplementedError(
|
||||
"MultiStepWorker does not support beam search.")
|
||||
|
||||
def maybe_load_lm_head_weight(
|
||||
self,
|
||||
lm_head_weight: torch.Tensor,
|
||||
) -> None:
|
||||
weight_loader = getattr(
|
||||
self.worker.model_runner.model_runner.model.lm_head.weight,
|
||||
"weight_loader", default_weight_loader)
|
||||
weight_loader(
|
||||
self.worker.model_runner.model_runner.model.lm_head.weight,
|
||||
lm_head_weight)
|
||||
@ -1,196 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
|
||||
class _DummyModel(nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class NGramWorker(NonLLMProposerWorkerBase):
|
||||
"""NGramWorker provides a light drafter without need for model.
|
||||
|
||||
Current NGramWorker only implements prompt lookup decoding,
|
||||
and in future we may also do RAG type drafter and other scenarios
|
||||
which don't rely on LLM model to give proposals.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
device_type: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(vllm_config)
|
||||
|
||||
# Get local_rank/vocab_size from kwargs attribute
|
||||
self.local_rank = local_rank
|
||||
self.device_type = device_type
|
||||
|
||||
# Lazy initialization list.
|
||||
self._proposer: Top1Proposer
|
||||
|
||||
def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
|
||||
ngram_prompt_lookup_max: int):
|
||||
# Search valid candidate window between
|
||||
# ngram_prompt_lookup_min/ngram_prompt_lookup_max
|
||||
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
|
||||
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
|
||||
|
||||
def init_device(self):
|
||||
self.device = torch.device(f"{self.device_type}:{self.local_rank}")
|
||||
|
||||
# Current NGramWorker only supports Top1Proposer
|
||||
self._proposer = Top1Proposer(
|
||||
weakref.proxy(self), # type: ignore[arg-type]
|
||||
device=self.device,
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
|
||||
def load_model(self) -> None:
|
||||
pass # Dummy
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return _DummyModel()
|
||||
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# Unused parameter. NGramWorker does not use the KV Cache and
|
||||
# therefore does not need this parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
|
||||
"""NGram match algo to pick proposal candidate. Returns the list of
|
||||
sampler output, one per SequenceGroupMetadata.
|
||||
|
||||
For ngram worker, we already done needed transposed internal, so the
|
||||
indicator pass to sampler_output_to_torch shall be False.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
has_spec_out = False
|
||||
token_id_list: List[Optional[torch.Tensor]] = []
|
||||
token_prob_list: List[Optional[torch.Tensor]] = []
|
||||
for idx, seq_group_metadata in enumerate(
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
# When seq_len is less than 3072 (3K), we use CPU to perform
|
||||
# the ngram match. Otherwise, we use the device specified in
|
||||
# the model config (normally GPU). 3072 is a rough threshold
|
||||
# based on profiling on H100, and it can be adjusted based
|
||||
# on the actual performance on different hardware.
|
||||
cur_device = "cpu" if seq_len < 3072 else self.device
|
||||
input_ids = torch.as_tensor(seq_data.get_token_ids(),
|
||||
dtype=torch.long,
|
||||
device=cur_device)
|
||||
input_length = seq_data.get_len()
|
||||
|
||||
for ngram_size in range(
|
||||
min(self.ngram_prompt_lookup_max, input_length - 1),
|
||||
self.ngram_prompt_lookup_min - 1,
|
||||
-1,
|
||||
):
|
||||
ngram_tensor = input_ids[-ngram_size:]
|
||||
if ngram_size == 1:
|
||||
# Do not match itself and do not use unfold and all
|
||||
matches = (input_ids[:-1] == ngram_tensor)
|
||||
else:
|
||||
windows = input_ids.unfold(dimension=0,
|
||||
size=ngram_size,
|
||||
step=1)
|
||||
# Do not match itself
|
||||
matches = (windows[:-1] == ngram_tensor).all(dim=-1)
|
||||
|
||||
# first_match includes "values" (bool), indicating whether
|
||||
# the match is found, and "indices", indicating the index
|
||||
# of the first match.
|
||||
first_match = matches.max(dim=-1)
|
||||
if first_match.values.item():
|
||||
proposal_start_idx = first_match.indices.add_(ngram_size)
|
||||
spec_indices = (
|
||||
proposal_start_idx).repeat(sample_len) + torch.arange(
|
||||
sample_len, device=cur_device)
|
||||
spec_indices.clamp_(max=input_ids.shape[-1] - 1)
|
||||
res = input_ids.gather(dim=-1,
|
||||
index=spec_indices).to(self.device)
|
||||
token_id_list.append(res)
|
||||
token_prob_list.append(
|
||||
torch.nn.functional.one_hot(
|
||||
res,
|
||||
num_classes=self.vocab_size).to(torch.float32))
|
||||
has_spec_out = True
|
||||
break
|
||||
else:
|
||||
token_id_list.append(None)
|
||||
token_prob_list.append(None)
|
||||
|
||||
if not has_spec_out:
|
||||
return None, False
|
||||
|
||||
outputs: List[Optional[SamplerOutput]] = []
|
||||
for idx in range(len(execute_model_req.seq_group_metadata_list)):
|
||||
if token_id_list[idx] is None:
|
||||
outputs.append(None)
|
||||
else:
|
||||
outputs.append(
|
||||
SamplerOutput(
|
||||
outputs=None,
|
||||
sampled_token_probs=token_prob_list[idx],
|
||||
logprobs=torch.zeros((sample_len, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device),
|
||||
sampled_token_ids=token_id_list[idx],
|
||||
))
|
||||
|
||||
return outputs, False
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
# Unused parameter. NGramWorker does not use the KV Cache and
|
||||
# therefore does not need this parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
return self._proposer.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def _raise_if_unsupported(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> None:
|
||||
"""NGramWorker does not yet implement support for cache swap
|
||||
operations or beam search.
|
||||
"""
|
||||
if any([
|
||||
execute_model_req.blocks_to_swap_in,
|
||||
execute_model_req.blocks_to_swap_out,
|
||||
execute_model_req.blocks_to_copy
|
||||
]):
|
||||
raise NotImplementedError(
|
||||
"NGramWorker does not support cache operations")
|
||||
|
||||
if any(
|
||||
len(seq_group_metadata.seq_data.keys()) != 1
|
||||
for seq_group_metadata in
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
raise NotImplementedError(
|
||||
"NGramWorker does not support beam search.")
|
||||
@ -1,59 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposer
|
||||
from vllm.worker.worker_base import LoRANotSupportedWorkerBase
|
||||
|
||||
|
||||
class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer):
|
||||
"""Interface for proposer workers"""
|
||||
|
||||
@abstractmethod
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# A set containing all sequence IDs that were assigned bonus tokens
|
||||
# in their last forward pass. This set is used to backfill the KV cache
|
||||
# with the key-value pairs of the penultimate token in the sequences.
|
||||
# This parameter is only used by the MultiStepWorker, which relies on
|
||||
# the KV cache for token generation. It is not used by workers that
|
||||
# do not utilize the KV cache.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int]
|
||||
) -> Tuple[Optional[List[SamplerOutput]], bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_include_gpu_probs_tensor(self) -> None:
|
||||
"""Implementation optional"""
|
||||
pass
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
|
||||
"""Implementation optional"""
|
||||
pass
|
||||
|
||||
|
||||
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
|
||||
"""Proposer worker which does not use a model with kvcache"""
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
"""get_spec_proposals is used to get the proposals"""
|
||||
return []
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""This is never called on the proposer, only the target model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
pass
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
return 0
|
||||
@ -1,196 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.distributed.parallel_state import (get_tp_group,
|
||||
init_model_parallel_group,
|
||||
patch_tensor_parallel_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _DummyModel(nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class SmallerTpProposerWorker(ProposerWorkerBase):
|
||||
"""Class which allows a speculative draft model to run with smaller tensor
|
||||
parallel degree than target model.
|
||||
This reduces the communication overhead of small draft models.
|
||||
|
||||
To implement this feature, this class differs behavior based on is_dummy
|
||||
flag, where dummy means worker that does not participate draft generation.
|
||||
Participating workers use a smaller tp group by patching vLLM's tensor
|
||||
parallel group temporarily during forward passes of draft models.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
|
||||
target_tensor_parallel_size: int):
|
||||
"""Wrap the worker in a SmallerTpProposerWorker if necessary.
|
||||
"""
|
||||
if draft_tensor_parallel_size == target_tensor_parallel_size:
|
||||
return worker
|
||||
|
||||
# gpu ranks that will generate draft tokens together
|
||||
draft_ranks = list(range(draft_tensor_parallel_size))
|
||||
|
||||
logger.info("Wrapping {%s} in {%s}", type(worker), cls)
|
||||
return cls(worker, draft_ranks)
|
||||
|
||||
def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
|
||||
"""Create a SmallerTpProposerWorker.
|
||||
|
||||
Args:
|
||||
worker (~vllm.spec_decode.multi_step_worker.MultiStepWorker): an
|
||||
actual worker wrapped with this class
|
||||
draft_ranks (List[int]): if this value is given, only the GPU ranks
|
||||
written in this value participate in draft generation
|
||||
"""
|
||||
self._worker = worker
|
||||
self._draft_ranks = draft_ranks
|
||||
|
||||
# init during init_device
|
||||
self._is_dummy = False
|
||||
self._tp_group = None
|
||||
|
||||
def _patch_tensor_parallel_group(self):
|
||||
"""Temporarily patch the global tp group state with its own tp group
|
||||
state.
|
||||
"""
|
||||
return patch_tensor_parallel_group(self._tp_group)
|
||||
|
||||
def init_device(self) -> None:
|
||||
self._is_dummy = get_tp_group().rank not in self._draft_ranks
|
||||
|
||||
# dummy workers do nothing
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
# creates tp process group containing only a subset of gpu ranks
|
||||
local_rank = get_tp_group().local_rank
|
||||
tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
|
||||
self._tp_group = init_model_parallel_group([self._draft_ranks],
|
||||
local_rank, tp_backend)
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.init_device()
|
||||
|
||||
def set_include_gpu_probs_tensor(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
# Need include_gpu_probs_tensor for multi_step_worker
|
||||
self._worker.set_include_gpu_probs_tensor()
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
self._worker.set_should_modify_greedy_probs_inplace()
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
if self._is_dummy:
|
||||
# this case is not used now
|
||||
return -1, -1
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
# Do not check _is_dummy, as it's always called by get_spec_proposals
|
||||
return self._worker.sampler_output(
|
||||
execute_model_req, sample_len,
|
||||
seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
if self._is_dummy:
|
||||
return SpeculativeProposals(None, None, None)
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
if self._is_dummy:
|
||||
return _DummyModel()
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.get_model()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
if self._is_dummy:
|
||||
return []
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.execute_model(execute_model_req)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
if self._is_dummy:
|
||||
# by returning zero, target worker can use the entire kv cache space
|
||||
return 0
|
||||
|
||||
return self._worker.get_cache_block_size_bytes()
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._worker.vocab_size
|
||||
|
||||
def maybe_load_lm_head_weight(
|
||||
self,
|
||||
lm_head_weight: torch.Tensor,
|
||||
) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
weight_loader = getattr(
|
||||
self._worker.worker.model_runner.model_runner.model.\
|
||||
lm_head.weight,
|
||||
"weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(
|
||||
self._worker.worker.model_runner.model_runner.model.\
|
||||
lm_head.weight,
|
||||
lm_head_weight)
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,45 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.worker.model_runner_base import (ModelRunnerBase,
|
||||
ModelRunnerInputBase,
|
||||
ModelRunnerWrapperBase)
|
||||
|
||||
|
||||
class TargetModelRunner(ModelRunnerWrapperBase):
|
||||
"""Specialized model runner for speculative decoding target model.
|
||||
In speculative decoding, the log probabilities selected finally may not
|
||||
be the same ones as selected by the target model sampling. This means
|
||||
that the time spent in the log probability calculation of the target model
|
||||
is time wasted, since we calculate log probabilities after deciding which
|
||||
tokens are accepted. For this reason disabling log probabilities in the
|
||||
target model will make decode faster. The model runner sets the
|
||||
SamplingMetadata parameters according to whether log probabilities are
|
||||
requested or not.
|
||||
"""
|
||||
|
||||
def __init__(self, model_runner: ModelRunnerBase):
|
||||
# An internal boolean member variable to indicate if token log
|
||||
# probabilities are needed or not.
|
||||
super().__init__(model_runner)
|
||||
self.disable_logprobs = True
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None,
|
||||
) -> ModelRunnerInputBase:
|
||||
model_input: ModelRunnerInputBase =\
|
||||
self.model_runner.prepare_model_input(
|
||||
seq_group_metadata_list, virtual_engine, finished_requests_ids)
|
||||
# If token log probabilities is disabled then skip generating sampler
|
||||
# CPU output. We directly serialize the GPU sampled_token_id tensors
|
||||
# as needed. If log probabilities is enabled then synchronize all the
|
||||
# sampling related tensors which includes the logprobs tensors.
|
||||
model_input.sampling_metadata.skip_sampler_cpu_output = (
|
||||
self.disable_logprobs)
|
||||
return model_input
|
||||
@ -1,275 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
|
||||
|
||||
class Top1Proposer(SpeculativeProposer):
|
||||
"""Helper class which separates out sequences which would exceed the max
|
||||
model length when speculated upon.
|
||||
|
||||
This allows combinations of models such as JackFram/llama-68m draft with
|
||||
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
|
||||
2048 while Llama2-13b has max_position_embeddings of 4096.
|
||||
|
||||
We treat the sequences which exceed the proposal draft model length as
|
||||
"non-spec sequences". Essentially they skip the draft model and go through
|
||||
normal decoding in the target model.
|
||||
|
||||
Currently, only proposal_lens of 0 and k are supported, where k is a global
|
||||
batch proposal length. In the future vLLM should support per-sequence
|
||||
proposal lengths.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker: ProposerWorkerBase,
|
||||
device: str,
|
||||
vocab_size: int,
|
||||
max_proposal_len: Optional[int] = None,
|
||||
):
|
||||
self._worker = worker
|
||||
self._device = device
|
||||
self.max_proposal_len = max_proposal_len
|
||||
self._vocab_size = vocab_size
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Get speculative proposals given the input batch.
|
||||
|
||||
Sequences which would exceed the max model length are skipped during
|
||||
speculation.
|
||||
"""
|
||||
proposal_len = execute_model_req.num_lookahead_slots
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
|
||||
# Split speculative- and non-speculative- sequences.
|
||||
(
|
||||
proposal_lens,
|
||||
nonzero_proposal_len_seqs,
|
||||
nonzero_proposal_len_indices,
|
||||
) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
|
||||
|
||||
if nonzero_proposal_len_seqs:
|
||||
# Speculate tokens using the draft worker for the speculative
|
||||
# sequences.
|
||||
# If sampler_transposed is true, then maybe_sampler_output's
|
||||
# token_ids is like [batch] format in proposal_len size list,
|
||||
# while if it is false, the format would be [proposal_len]
|
||||
# in batch size list
|
||||
hidden_states = execute_model_req.previous_hidden_states
|
||||
if hidden_states is not None:
|
||||
hidden_states.prune(nonzero_proposal_len_seqs)
|
||||
nonzero_execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=nonzero_proposal_len_seqs,
|
||||
num_lookahead_slots=proposal_len,
|
||||
previous_hidden_states=hidden_states,
|
||||
)
|
||||
maybe_sampler_output, transposed = self._worker.sampler_output(
|
||||
execute_model_req=nonzero_execute_model_req,
|
||||
sample_len=proposal_len,
|
||||
seq_ids_with_bonus_token_in_last_step=\
|
||||
seq_ids_with_bonus_token_in_last_step,
|
||||
)
|
||||
(
|
||||
proposal_lens,
|
||||
maybe_sampler_output,
|
||||
nonzero_proposal_len_indices,
|
||||
) = self._remove_no_proposal_seqs(proposal_lens,
|
||||
maybe_sampler_output,
|
||||
nonzero_proposal_len_indices,
|
||||
transposed)
|
||||
else:
|
||||
# If no sequences can be speculated, set sampler output to None.
|
||||
maybe_sampler_output = None
|
||||
transposed = False
|
||||
|
||||
# Combine speculative- and non-speculative sequences into the same
|
||||
# representation.
|
||||
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
|
||||
batch_size=len(seq_group_metadata_list),
|
||||
proposal_len=proposal_len,
|
||||
maybe_sampler_output=maybe_sampler_output,
|
||||
proposal_lens=proposal_lens,
|
||||
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
|
||||
sampler_transposed=transposed,
|
||||
)
|
||||
|
||||
proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens,
|
||||
no_proposals=maybe_sampler_output
|
||||
is None)
|
||||
return proposals
|
||||
|
||||
def _split_by_proposal_len(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_len: int,
|
||||
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
|
||||
"""Split sequences by two groups:
|
||||
1. Sequences with non-zero proposal length.
|
||||
2. Sequences with zero proposal length (due to disabled speculation
|
||||
or exceed the maximum model length).
|
||||
"""
|
||||
|
||||
proposal_lens: List[int] = []
|
||||
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
|
||||
nonzero_proposal_len_indices: List[int] = []
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
# The speculative decoding for this request has either been disabled
|
||||
# (e.g. due to high traffic) or this is a prompt request.
|
||||
if (seq_group_metadata.is_prompt
|
||||
or seq_group_metadata.num_speculative_tokens == 0):
|
||||
proposal_lens.append(0)
|
||||
continue
|
||||
|
||||
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
||||
seq_len = seq_data.get_len()
|
||||
|
||||
# Currently only proposal lens of 0 or the global batch proposal len
|
||||
# are supported.
|
||||
# If max_proposal_len is defined, then we shall not exceed this
|
||||
# quota for nonzero_proposal
|
||||
new_k = 0
|
||||
if (self.max_proposal_len is None
|
||||
or seq_len + proposal_len < self.max_proposal_len):
|
||||
new_k = proposal_len
|
||||
nonzero_proposal_len_seqs.append(seq_group_metadata)
|
||||
nonzero_proposal_len_indices.append(i)
|
||||
proposal_lens.append(new_k)
|
||||
seq_group_metadata.num_speculative_tokens = new_k
|
||||
|
||||
return (
|
||||
proposal_lens,
|
||||
nonzero_proposal_len_seqs,
|
||||
nonzero_proposal_len_indices,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
|
||||
nonzero_proposal_len_indices, transposed):
|
||||
"""Remove sequences from nonzero_proposal_len_indices and reset
|
||||
their proposal_len to 0 the draft worker does not provide a proposal
|
||||
(maybe_sampler_output=None). This can avoid scoring overheads.
|
||||
"""
|
||||
|
||||
# If maybe_sampler_output is None, then the draft worker did not
|
||||
# provide a proposal for any sequence and thus no action needed.
|
||||
# Also we do not support transposed maybe_sampler_output for now
|
||||
# because it seems not straightforward for draft workers outputting
|
||||
# transposed sampler outputs to handle the case of no proposal.
|
||||
if maybe_sampler_output is None or transposed:
|
||||
return (proposal_lens, maybe_sampler_output,
|
||||
nonzero_proposal_len_indices)
|
||||
|
||||
new_proposal_lens: List[int] = []
|
||||
new_nonzero_proposal_len_indices: List[int] = []
|
||||
new_maybe_sampler_output: List[SamplerOutput] = []
|
||||
nonzero_proposal_len_idx_ptr = 0
|
||||
seq_idx = 0
|
||||
while seq_idx < len(
|
||||
proposal_lens) and nonzero_proposal_len_idx_ptr < len(
|
||||
nonzero_proposal_len_indices):
|
||||
if seq_idx < nonzero_proposal_len_indices[
|
||||
nonzero_proposal_len_idx_ptr]:
|
||||
# Sequence is not in the original nonzero_proposal_len_indices,
|
||||
# meaning that it has a proposal length of 0 before sending to
|
||||
# the draft worker.
|
||||
assert proposal_lens[seq_idx] == 0
|
||||
new_proposal_lens.append(0)
|
||||
else:
|
||||
# Sequence is in the original nonzero_proposal_len_indices
|
||||
if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
|
||||
# but does not have a proposal from the draft worker.
|
||||
new_proposal_lens.append(0)
|
||||
else:
|
||||
# and has a proposal from the draft worker. Add it to the
|
||||
# new nonzero proposal list and keep the sampler output.
|
||||
new_proposal_lens.append(proposal_lens[seq_idx])
|
||||
new_nonzero_proposal_len_indices.append(seq_idx)
|
||||
new_maybe_sampler_output.append(
|
||||
maybe_sampler_output[nonzero_proposal_len_idx_ptr])
|
||||
nonzero_proposal_len_idx_ptr += 1
|
||||
seq_idx += 1
|
||||
|
||||
# The remaining sequences should have proposal length of 0.
|
||||
new_proposal_lens.extend(proposal_lens[seq_idx:])
|
||||
|
||||
# We assume sampler_output will not be a list of all Nones.
|
||||
# In this case this function should not be called.
|
||||
assert new_maybe_sampler_output
|
||||
return (new_proposal_lens, new_maybe_sampler_output,
|
||||
new_nonzero_proposal_len_indices)
|
||||
|
||||
def _merge_outputs(
|
||||
self,
|
||||
batch_size: int,
|
||||
proposal_len: int,
|
||||
maybe_sampler_output: Optional[List[SamplerOutput]],
|
||||
proposal_lens: List[int],
|
||||
nonzero_proposal_len_indices: List[int],
|
||||
sampler_transposed: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""After speculations are produced, merge the speculation results with
|
||||
the skipped sequences.
|
||||
"""
|
||||
if maybe_sampler_output is None:
|
||||
# If no speculative tokens, the sampler output will be None.
|
||||
# In this case we return empty proposals.
|
||||
proposal_tokens = torch.tensor(-1,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len)
|
||||
proposal_probs = torch.tensor(0,
|
||||
dtype=torch.float32,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len,
|
||||
self._vocab_size)
|
||||
proposal_lens_tensor = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
len(proposal_lens))
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
|
||||
sampler_output, sampler_transposed)
|
||||
|
||||
# Now, reformat the output GPU tensors such that each sequence has
|
||||
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
|
||||
|
||||
entire_proposal_tokens = proposal_tokens.new_full(
|
||||
size=(batch_size, *proposal_tokens.shape[1:]),
|
||||
fill_value=-1,
|
||||
)
|
||||
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
|
||||
entire_proposal_probs = proposal_probs.new_zeros(
|
||||
batch_size,
|
||||
*proposal_probs.shape[1:],
|
||||
)
|
||||
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
|
||||
|
||||
proposal_tokens, proposal_probs = (
|
||||
entire_proposal_tokens,
|
||||
entire_proposal_probs,
|
||||
)
|
||||
|
||||
proposal_lens_tensor = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
|
||||
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
@ -1,277 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
PromptLogprobs, SequenceGroupMetadata,
|
||||
SequenceOutput)
|
||||
|
||||
SeqId = int
|
||||
|
||||
|
||||
def get_all_num_logprobs(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
|
||||
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
|
||||
|
||||
If the sampling params do not call for any logprobs, return 0 for that
|
||||
sequence.
|
||||
"""
|
||||
|
||||
all_num_logprobs: List[int] = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
num_logprobs = seq_group_metadata.sampling_params.logprobs
|
||||
if num_logprobs is None:
|
||||
num_logprobs = 0
|
||||
all_num_logprobs.append(num_logprobs)
|
||||
|
||||
return all_num_logprobs
|
||||
|
||||
|
||||
def get_sampled_token_logprobs(
|
||||
# shape [num_steps, batch_size, vocab_size]
|
||||
logprob_tensor: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
|
||||
"""
|
||||
num_steps, batch_size, vocab_size = logprob_tensor.shape
|
||||
|
||||
selected_logprobs = logprob_tensor[
|
||||
torch.arange(num_steps).unsqueeze(1),
|
||||
torch.arange(batch_size),
|
||||
sampled_token_ids,
|
||||
]
|
||||
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
|
||||
-1, -1, vocab_size)
|
||||
sampled_token_ids_ranks = (logprob_tensor
|
||||
> expanded_selected_logprobs).sum(-1).add_(1)
|
||||
|
||||
return sampled_token_ids_ranks, selected_logprobs
|
||||
|
||||
|
||||
def create_logprobs_output(
|
||||
token_id: int,
|
||||
token_id_logprob_rank: int,
|
||||
token_id_logprob: float,
|
||||
topk_token_ids: List[Optional[int]],
|
||||
topk_logprobs: List[Optional[float]],
|
||||
) -> Dict[int, Logprob]:
|
||||
"""Create a Logprob Dict for a token given the sampling results.
|
||||
|
||||
Args:
|
||||
token_id (int): The sampled token for the sequence.
|
||||
token_id_logprob_rank (int): The logprob rank of the sampled token.
|
||||
token_id_logprob (float): The logprob value of the sampled token.
|
||||
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
|
||||
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
|
||||
"""
|
||||
# vLLM logprobs always include the sampled token. In addition, the user may
|
||||
# request topk-logprobs (where top-k varies per user up to max_logprobs).
|
||||
logprobs: Dict[int, Logprob] = {
|
||||
token_id: Logprob(
|
||||
logprob=token_id_logprob,
|
||||
rank=token_id_logprob_rank,
|
||||
),
|
||||
}
|
||||
logprobs.update({
|
||||
topk_token_id: Logprob(
|
||||
logprob=topk_logprob if topk_logprob is not None else 0.0,
|
||||
rank=topk_index + 1,
|
||||
)
|
||||
for topk_index, (topk_token_id, topk_logprob) \
|
||||
in enumerate(zip(topk_token_ids, topk_logprobs)) \
|
||||
if topk_token_id is not None
|
||||
})
|
||||
|
||||
return logprobs
|
||||
|
||||
|
||||
def create_sequence_group_output(
|
||||
token_id: int,
|
||||
token_id_logprob_rank: int,
|
||||
token_id_logprob: float,
|
||||
seq_id: SeqId,
|
||||
topk_token_ids: List[Optional[int]],
|
||||
topk_logprobs: List[Optional[float]],
|
||||
prompt_logprobs: Optional[PromptLogprobs] = None,
|
||||
step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
|
||||
"""Create a SequenceGroupOutput given the sampling results.
|
||||
|
||||
Args:
|
||||
token_id (int): The sampled token for the sequence.
|
||||
token_id_logprob_rank (int): The logprob rank of the sampled token.
|
||||
token_id_logprob (float): The logprob value of the sampled token.
|
||||
seq_id (int): The sequence id.
|
||||
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
|
||||
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
|
||||
step_index: (Optional[int]): The index of the speculative token.
|
||||
"""
|
||||
|
||||
logprobs = create_logprobs_output(
|
||||
token_id,
|
||||
token_id_logprob_rank,
|
||||
token_id_logprob,
|
||||
topk_token_ids,
|
||||
topk_logprobs,
|
||||
)
|
||||
|
||||
return CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs=logprobs)
|
||||
],
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
step_index=step_index)
|
||||
|
||||
|
||||
def split_batch_by_proposal_len(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_lens: List[int],
|
||||
) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
|
||||
List[SequenceGroupMetadata], List[int]]]:
|
||||
"""Utility function that splits a batch based on whether the proposal len is
|
||||
zero or not. We should remove this once vLLM supports per-sequence proposal
|
||||
lens in a batch.
|
||||
"""
|
||||
|
||||
nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
|
||||
zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
|
||||
for i, (seq_group, proposal_len) in enumerate(
|
||||
zip(seq_group_metadata_list, proposal_lens)):
|
||||
seq_groups, indices = nonzero_lists if proposal_len else zero_lists
|
||||
seq_groups.append(seq_group)
|
||||
indices.append(i)
|
||||
return nonzero_lists, zero_lists
|
||||
|
||||
|
||||
def sampler_output_to_torch(
|
||||
sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Utility function which converts a list of SamplerOutput to tensors.
|
||||
|
||||
sampler_transposed here is used as the indicator for whether
|
||||
we need do additional tensor transpose logic here.
|
||||
|
||||
Returns:
|
||||
sampled_token_ids: torch.Tensor
|
||||
shape: [batch_size, len(sampler_output_list)]
|
||||
|
||||
sampled_token_probs: torch.Tensor
|
||||
shape: [batch_size, len(sampler_output_list), vocab_size]
|
||||
"""
|
||||
|
||||
# shape: [batch_size, num_sampler_output, vocab_size]
|
||||
sampled_token_probs = torch.stack(
|
||||
[
|
||||
sampler_output.sampled_token_probs
|
||||
for sampler_output in sampler_output_list
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# shape: [batch_size, num_sampler_output, vocab_size]
|
||||
sampled_token_logprobs = torch.stack(
|
||||
[sampler_output.logprobs for sampler_output in sampler_output_list],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# shape: [batch_size, num_sampler_output]
|
||||
sampled_token_ids = torch.stack(
|
||||
[
|
||||
sampler_output.sampled_token_ids.flatten()
|
||||
for sampler_output in sampler_output_list
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if sampler_transposed:
|
||||
sampled_token_probs = sampled_token_probs.transpose(0, 1)
|
||||
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
|
||||
sampled_token_ids = sampled_token_ids.transpose(0, 1)
|
||||
|
||||
if sampler_output_list[0].hidden_states is not None:
|
||||
# shape: [batch_size, num_sampler_output, hidden_dim]
|
||||
sampled_hidden_states = torch.stack(
|
||||
[
|
||||
sampler_output.hidden_states
|
||||
for sampler_output in sampler_output_list
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if sampler_transposed:
|
||||
sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
|
||||
else:
|
||||
sampled_hidden_states = None
|
||||
|
||||
return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
|
||||
sampled_hidden_states)
|
||||
|
||||
|
||||
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
|
||||
vocab_size: int, device: str) -> None:
|
||||
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
|
||||
values. This will be removed in PR 7/9.
|
||||
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
|
||||
"""
|
||||
values = [
|
||||
sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
|
||||
]
|
||||
assert all(v is None for v in values) or not any(v is None for v in values)
|
||||
if not any(v is None for v in values):
|
||||
# Do nothing if the tensors are already created (usually in unit tests).
|
||||
return
|
||||
|
||||
# Softmax to ensure valid probs.
|
||||
sampler_output.sampled_token_probs = torch.nn.functional.softmax(
|
||||
torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
|
||||
dim=-1)
|
||||
|
||||
sampler_output.sampled_token_ids = torch.randint(low=10,
|
||||
high=100,
|
||||
size=(batch_size, ),
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def nvtx_range(msg, *args, **kwargs):
|
||||
"""
|
||||
Context manager / decorator that pushes an NVTX range at the beginning
|
||||
of its scope, and pops it at the end. If extra arguments are given,
|
||||
they are passed as arguments to msg.format().
|
||||
|
||||
If running with cuda graphs, you must enable nsys cuda graph profiling.
|
||||
|
||||
Arguments:
|
||||
msg (string): message to associate with the range
|
||||
"""
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Basic timer context manager for measuring CPU time.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.end_time = time.time()
|
||||
self.elapsed_time_s = self.end_time - self.start_time
|
||||
self.elapsed_time_ms = self.elapsed_time_s * 1000
|
||||
@ -6,7 +6,6 @@ from typing import Optional, Union
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
||||
|
||||
|
||||
@ -44,28 +43,25 @@ class EAGLEConfig(PretrainedConfig):
|
||||
self.truncated_vocab_size = self.model.vocab_size if \
|
||||
truncated_vocab_size is None else truncated_vocab_size
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
kwargs["architectures"] = ["EAGLEModel"]
|
||||
# Eagle model name should follow naming convention of
|
||||
# LlamaForCausalLM -> EagleLlamaForCausalLM
|
||||
if method == "eagle":
|
||||
assert self.model is not None, \
|
||||
"model should not be None when method is eagle"
|
||||
kwargs["architectures"] = [
|
||||
f"Eagle{arch}" if not arch.startswith("Eagle") \
|
||||
else arch for arch in self.model.architectures
|
||||
]
|
||||
elif method == "eagle3":
|
||||
assert self.model is not None, \
|
||||
"model should not be None when method is eagle3"
|
||||
kwargs["architectures"] = [
|
||||
f"Eagle3{arch}" if not arch.startswith("Eagle3") \
|
||||
else arch for arch in self.model.architectures
|
||||
]
|
||||
else:
|
||||
# Eagle model name should follow naming convention of
|
||||
# LlamaForCausalLM -> EagleLlamaForCausalLM
|
||||
if method == "eagle":
|
||||
assert self.model is not None, \
|
||||
"model should not be None when method is eagle"
|
||||
kwargs["architectures"] = [
|
||||
f"Eagle{arch}" if not arch.startswith("Eagle") \
|
||||
else arch for arch in self.model.architectures
|
||||
]
|
||||
elif method == "eagle3":
|
||||
assert self.model is not None, \
|
||||
"model should not be None when method is eagle3"
|
||||
kwargs["architectures"] = [
|
||||
f"Eagle3{arch}" if not arch.startswith("Eagle3") \
|
||||
else arch for arch in self.model.architectures
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Invalid method {method}. \
|
||||
Supported methods are eagle and eagle3.")
|
||||
raise ValueError(f"Invalid method {method}. \
|
||||
Supported methods are eagle and eagle3.")
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@ -397,8 +397,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
|
||||
model_input, worker_input, kwargs = inputs
|
||||
num_steps = worker_input.num_steps
|
||||
if execute_model_req is not None and execute_model_req.spec_step_idx:
|
||||
kwargs["spec_step_idx"] = execute_model_req.spec_step_idx
|
||||
|
||||
self.execute_worker(worker_input)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user