[Sampler] Adapt to FlashInfer 0.2.3 sampler API (#15777)
Signed-off-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@ -169,7 +169,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
@pytest.mark.parametrize("n_rep", [100])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# @pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# Not testing FlashInfer now, since 0.2.3 API removed the ability
|
||||
# to pass in uniform samples.
|
||||
@pytest.mark.parametrize("use_flashinfer", [False])
|
||||
@torch.inference_mode()
|
||||
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
frac_seeded: float, n_rep: int, device: str,
|
||||
@ -214,7 +217,10 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# @pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# Not testing FlashInfer now, since 0.2.3 API removed the ability
|
||||
# to pass in uniform samples.
|
||||
@pytest.mark.parametrize("use_flashinfer", [False])
|
||||
@torch.inference_mode()
|
||||
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
|
||||
device: str, use_flashinfer: bool):
|
||||
@ -284,6 +290,10 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
|
||||
Test the flashinfer and nonflashinfer backend generate
|
||||
the same output metrics.
|
||||
"""
|
||||
|
||||
pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
|
||||
"the ability to pass in uniform samples.")
|
||||
|
||||
torch.set_default_device(device)
|
||||
torch.manual_seed(0)
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
|
||||
@ -647,6 +647,8 @@ def test_flashinfer_fallback(seed: int, device: str):
|
||||
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||
pytest.skip("Flashinfer sampler is disabled")
|
||||
|
||||
pytest.skip("After FlashInfer 0.2.3, sampling will never fail")
|
||||
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
|
||||
@ -1,14 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
|
||||
from torch import Generator
|
||||
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
|
||||
is_flashinfer_available)
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
VOCAB_SIZE = 128 * 1024
|
||||
|
||||
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
|
||||
|
||||
|
||||
def test_topk_impl_equivalance():
|
||||
|
||||
@ -35,3 +41,67 @@ def test_topk_impl_equivalance():
|
||||
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||
|
||||
assert torch.allclose(result1, result2)
|
||||
|
||||
|
||||
def test_flashinfer_sampler():
|
||||
'''
|
||||
This test verifies that the FlashInfer top-k and top-p sampling
|
||||
implementation produces the same results as the Python implementation.
|
||||
|
||||
NOTE: FlashInfer did not directly expose an interface for fused top-k and
|
||||
top-p prob renorm (it did provide fused sampling but we cannot compare
|
||||
sampling results due to randomness), so we will compare the probability
|
||||
renormed consequently by top-k and then top-p of FlashInfer implementation.
|
||||
'''
|
||||
|
||||
if not FLASHINFER_ENABLED:
|
||||
pytest.skip(
|
||||
"FlashInfer not installed or not available on this platform.")
|
||||
|
||||
with torch.device(DEVICE):
|
||||
generator = Generator(device=DEVICE).manual_seed(42)
|
||||
|
||||
# Generate random logits
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
|
||||
|
||||
# Generate various top-k and top-p values
|
||||
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
|
||||
p_values = torch.rand(
|
||||
(BATCH_SIZE, ),
|
||||
generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
|
||||
|
||||
# Sometimes disable top-k (k=vocab_size)
|
||||
k_values.masked_fill_(
|
||||
torch.randint(0,
|
||||
2, (BATCH_SIZE, ),
|
||||
generator=generator,
|
||||
dtype=torch.bool), VOCAB_SIZE)
|
||||
|
||||
# Sometimes disable top-p (p=1.0)
|
||||
p_values.masked_fill_(
|
||||
torch.randint(0,
|
||||
2, (BATCH_SIZE, ),
|
||||
generator=generator,
|
||||
dtype=torch.bool), 1.0)
|
||||
|
||||
python_logits = apply_top_k_top_p(
|
||||
logits=logits.clone(),
|
||||
k=k_values,
|
||||
p=p_values,
|
||||
)
|
||||
python_probs = torch.softmax(python_logits, dim=-1)
|
||||
|
||||
# FlashInfer only exposed renorm interfaces for probs so convert first
|
||||
flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
|
||||
flashinfer_probs = top_k_renorm_probs(
|
||||
probs=flashinfer_probs,
|
||||
top_k=k_values,
|
||||
)
|
||||
flashinfer_probs = top_p_renorm_probs(
|
||||
probs=flashinfer_probs,
|
||||
top_p=p_values,
|
||||
)
|
||||
|
||||
# Compare the results
|
||||
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
|
||||
"FlashInfer and Python sampling implementations do not match!"
|
||||
|
||||
Reference in New Issue
Block a user