[Deepseek v3.2] Remove extra logics in indexer (#26465)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com> Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> Signed-off-by: Lain <siyuanf@nvidia.com> Co-authored-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
@ -10,6 +10,8 @@ from vllm.platforms import current_platform
|
||||
# Test parameters
# Value grids for the parametrized tests in this file.
|
||||
# NOTE(review): presumably the row counts for test_top_k_per_row; its
# decorators are not visible in this view -- confirm.
NUM_ROWS = [1, 32, 2050]
|
||||
# Values for the "top_k" test argument (number of indices selected per row).
TOP_K_VALUES = [2048]
|
||||
# Values for the "batch_size" argument of test_top_k_per_row_decode.
BATCH_SIZE = [1, 2, 4, 2048, 4096]
|
||||
# Values for the "next_n" argument of test_top_k_per_row_decode; the test
# builds batch_size * next_n logits rows.
NEXT_N = [1, 2, 4, 8]
|
||||
|
||||
|
||||
def create_random_logits(
|
||||
@ -114,7 +116,7 @@ def test_top_k_per_row(
|
||||
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
|
||||
|
||||
# Create output tensors
|
||||
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
|
||||
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
|
||||
# Run CUDA implementation
|
||||
torch.ops._C.top_k_per_row(
|
||||
@ -138,3 +140,59 @@ def test_top_k_per_row(
|
||||
assert compare_top_k_results(
|
||||
logits, indices, torch_indices, row_starts, row_ends, top_k
|
||||
), "CUDA top_k_per_row results don't match torch.topk"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
    top_k: int,
    batch_size: int,
    next_n: int,
) -> None:
    """Check ``top_k_per_row_decode`` against a ``torch.topk`` reference.

    The op is driven by a per-request ``seq_lens`` tensor instead of explicit
    row bounds. Each of the ``batch_size`` requests contributes ``next_n``
    consecutive logits rows, and the expected bounds of every row are rebuilt
    here from ``seq_lens`` for the reference computation.

    Args:
        top_k: Number of indices written per row.
        batch_size: Number of entries in ``seq_lens``.
        next_n: Number of logits rows per ``seq_lens`` entry.
    """
    torch.set_default_device("cuda:0")

    # Create test data
    num_rows = batch_size * next_n
    vocab_size = 20000
    # NOTE(review): the lower bound of randint is 0, so a sampled length can
    # be smaller than next_n, which makes the matching row_ends non-positive.
    # Confirm create_random_logits and the kernel are meant to handle that.
    seq_lens = torch.randint(
        vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
    )
    row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
    # Row r belongs to request r // next_n and is its (r % next_n)-th row.
    row_indices = torch.arange(num_rows, device="cuda") // next_n
    next_n_offset = torch.arange(num_rows, device="cuda") % next_n
    # The last row of a request ends at seq_len; each earlier row ends one
    # position sooner.
    row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
    logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)

    # Create output tensors
    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")

    # Run CUDA implementation
    torch.ops._C.top_k_per_row_decode(
        logits,
        next_n,
        seq_lens,
        indices,
        num_rows,
        logits.stride(0),
        logits.stride(1),
    )

    torch.cuda.synchronize()

    # Run reference implementation
    # Reduce on the device and read back one scalar; the Python builtin max()
    # would walk the CUDA tensor element by element.
    max_row_end = int(row_ends.max().item())
    torch_indices = logits.topk(min(top_k, max_row_end), dim=-1)[1]
    # Keep only reference indices inside [0, row_end - row_start); everything
    # else is replaced by -1.
    mask_lo = torch_indices >= 0
    mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
    mask = mask_lo & mask_hi
    torch_indices = torch_indices.masked_fill(~mask, -1)

    # Compare results
    assert compare_top_k_results(
        logits, indices, torch_indices, row_starts, row_ends, top_k
    ), "CUDA top_k_per_row_decode results don't match torch.topk"
|
||||
|
||||
Reference in New Issue
Block a user