[V1] LoRA - Add triton kernels for V1 (#13096)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath
2025-03-10 17:27:53 -04:00
committed by GitHub
parent 0967110e42
commit 5ff0d32580
11 changed files with 1165 additions and 191 deletions

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import importlib
import random
from copy import deepcopy
from dataclasses import dataclass
@@ -63,6 +64,36 @@ DEVICES = ([
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]
# With the inclusion of V1 tests (see the run_with_both_engines_lora fixture),
# the tests in this file run twice: once with the V0 engine and then with
# the V1 engine.
# NUM_RANDOM_SEEDS was previously 10. It is halved with the inclusion of the
# V1 tests to keep CI test times roughly unchanged.
NUM_RANDOM_SEEDS = 5
# VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS was previously 256. It is
# halved with the inclusion of the V1 tests to keep CI test times roughly
# unchanged.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run each test with both engines.
# This can be promoted up to conftest.py to run for every
# test in a package.
# Reload punica_gpu as the kernels used are tied to engine type.
from vllm.lora.punica_wrapper import punica_gpu
importlib.reload(punica_gpu)
# Release any memory we might be holding on to; otherwise CI runs OOM.
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
_LORA_B_PTR_DICT)
_LORA_B_PTR_DICT.clear()
_LORA_A_PTR_DICT.clear()
yield
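For context, the fixture above depends on a run_with_both_engines_lora fixture defined elsewhere. A minimal sketch of what such a fixture could look like if promoted to conftest.py, assuming the engine is selected through the VLLM_USE_V1 environment variable (the parametrization below is illustrative, not the actual vLLM implementation):

import pytest


@pytest.fixture(params=["0", "1"], ids=["v0_engine", "v1_engine"])
def run_with_both_engines_lora(request, monkeypatch):
    # Hypothetical sketch: toggle the engine via VLLM_USE_V1 so that every
    # test requesting this fixture runs once per engine.
    monkeypatch.setenv("VLLM_USE_V1", request.param)
    yield
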
def get_random_id_to_index(num_loras: int,
num_slots: int,
@@ -226,7 +257,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -241,7 +272,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
return embedding, lora_embedding
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -329,7 +360,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -353,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
return expanded_embedding, lora_embedding
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -468,7 +499,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -490,7 +521,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
return linear, logits_processor, lora_logits_processor
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -600,10 +631,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16,
@@ -627,7 +658,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -716,10 +747,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
@@ -753,7 +784,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -842,10 +873,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
@@ -900,7 +931,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -1002,12 +1033,12 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
is_neox_style, rotary_dim, head_size,
seq_len) -> None:
dtype = torch.float16
max_loras = 8
seed = 0
current_platform.seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
long_lora_scaling_factors=scaling_factors,
@@ -1083,7 +1114,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
@pytest.mark.parametrize(
"seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
random.seed(seed)
vocab_size = random.randint(4000, 64000)

View File

@@ -5,10 +5,12 @@ import pytest
import torch
import vllm.lora.ops.triton_ops # noqa: F401
import vllm.lora.ops.triton_ops.v1 # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
from vllm.platforms import current_platform
from .utils import (PunicaTensors, assert_close, generate_data,
@@ -91,12 +93,12 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
_dict_lock = Lock()
def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
def check_shrink_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
"""
Compare outputs of vllm.sgmv_shrink kernel against a reference
implementation.
Compare outputs of the vllm.sgmv_shrink and vllm.v1_shrink kernels against a
reference implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@@ -111,44 +113,63 @@ def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
)
max_seq_length, token_nums = data.meta()
# Setup metadata information for SGMV and reference kernels
sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
# Prevent errors from stale cached pointers.
with _dict_lock:
# SGMV shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
sgmv_out_tensor,
*sgmv_meta_args,
scaling,
)
sgmv_shrink_for_nslices(
nslices,
# V1 shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.v1_shrink(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
scaling,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
# Reference
sgmv_shrink_for_nslices(
nslices,
data.inputs_tensor,
data.lora_weights,
ref_out_tensor,
*sgmv_meta_args,
scaling,
)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
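Conceptually, both shrink kernels compute, for every token, a scaled matmul of the token's activations with the LoRA A weights of that token's adapter; the SGMV path is driven by the per-sequence metadata in sgmv_meta_args, while the V1 path uses the per-token mapping prepared by V1KernelMeta. A rough single-slice reference in plain PyTorch, under assumed tensor layouts (illustrative only, not taken from this commit):

import torch


def shrink_reference(inputs, lora_a, token_lora_mapping, scaling):
    # Assumed layouts: inputs [num_tokens, hidden], lora_a [num_loras, rank,
    # hidden], token_lora_mapping [num_tokens] with -1 meaning "no adapter".
    out = torch.zeros(inputs.shape[0], lora_a.shape[1],
                      dtype=torch.float32, device=inputs.device)
    for t, lora_idx in enumerate(token_lora_mapping.tolist()):
        if lora_idx < 0:
            continue  # token has no active LoRA adapter
        out[t] = scaling * (inputs[t].float() @ lora_a[lora_idx].float().t())
    return out
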
def check_sgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
def check_expand_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
"""
Compare outputs of vllm.sgmv_expand kernel against a reference
implementation.
Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a
reference implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@@ -164,36 +185,54 @@ def check_sgmv_expand(batches: int, num_loras: int, rank: int,
max_seq_length, token_nums = data.meta()
# Setup metadata information for SGMV and reference kernels
sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
# Setup output tensors
ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
with _dict_lock:
# SGMV expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
sgmv_out_tensor,
*sgmv_meta_args,
offset_start=0,
add_inputs=add_inputs,
)
# V1 expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.v1_expand(data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
offset_start=0,
add_inputs=add_inputs)
# Reference
sgmv_expand_for_nslices(nslices,
hidden_size,
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
ref_out_tensor,
*sgmv_meta_args,
add_inputs=add_inputs)
assert_close(data.our_out_tensor, data.ref_out_tensor)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
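The expand kernels are the complement: each slice's low-rank intermediate is multiplied by the corresponding LoRA B weights and written into a column slice of the output, accumulating into the existing values when add_inputs is set. A rough single-slice reference under the same assumed layouts (illustrative only):

import torch


def expand_reference(inputs, lora_b, out, token_lora_mapping,
                     offset_start, add_inputs):
    # Assumed layouts: inputs [num_tokens, rank], lora_b [num_loras, hidden,
    # rank], out [num_tokens, >= offset_start + hidden].
    hidden = lora_b.shape[1]
    col = slice(offset_start, offset_start + hidden)
    for t, lora_idx in enumerate(token_lora_mapping.tolist()):
        if lora_idx < 0:
            continue  # token has no active LoRA adapter
        delta = (inputs[t] @ lora_b[lora_idx].t()).to(out.dtype)
        out[t, col] = out[t, col] + delta if add_inputs else delta
    return out
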
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
@@ -439,7 +478,7 @@ SEED = [0]
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv(
def test_kernels(
batches: int,
num_loras: int,
rank: int,
@@ -450,29 +489,32 @@ def test_punica_sgmv(
seed: int,
op_type: str,
):
"""
Tests SGMV and V1 kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
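Taken together, a shrink followed by an expand reproduces the standard LoRA update for each token, y += scaling * x @ A^T @ B^T, which is what both kernel families are ultimately verified against. A tiny self-contained check of that identity (shapes are arbitrary and chosen for illustration only):

import torch

hidden_in, hidden_out, rank, scaling = 16, 32, 4, 0.5
x = torch.randn(hidden_in)
A = torch.randn(rank, hidden_in)        # LoRA A ("shrink") weights
B = torch.randn(hidden_out, rank)       # LoRA B ("expand") weights
shrunk = scaling * (x @ A.t())          # what the shrink op produces
expanded = shrunk @ B.t()               # what the expand op produces
assert torch.allclose(expanded, scaling * x @ A.t() @ B.t(), atol=1e-5)
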
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@@ -484,7 +526,7 @@ def test_punica_sgmv(
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv_hidden_size(
def test_kernels_hidden_size(
batches: int,
num_loras: int,
rank: int,
@@ -495,29 +537,32 @@ def test_punica_sgmv_hidden_size(
seed: int,
op_type: str,
):
"""
Tests SGMV and V1 kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
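Both renamed tests can be run selectively while iterating on the kernels; a hypothetical invocation (the test-file location is an assumption, not taken from this commit):

import pytest

# Hypothetical: run only the renamed kernel tests under tests/lora/.
pytest.main(["-q", "-k", "test_kernels or test_kernels_hidden_size",
             "tests/lora/"])
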
@pytest.mark.parametrize("batches", test_params['batches'])