[V1] LoRA - Add triton kernels for V1 (#13096)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent 0967110e42, commit 5ff0d32580, committed by GitHub
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import importlib
 import random
 from copy import deepcopy
 from dataclasses import dataclass
@@ -63,6 +64,36 @@ DEVICES = ([
 # stages, so we need to verify this. prefill stage(True) or decode stage(False)
 STAGES = [True, False]

+# With the inclusion of the V1 tests (see run_with_both_engines_lora),
+# the tests in this file run twice: once with the V0 engine and then with
+# the V1 engine.
+# The NUM_RANDOM_SEEDS value was previously 10. It is cut in half with
+# the inclusion of the V1 tests to keep CI test times unchanged.
+NUM_RANDOM_SEEDS = 5
+# The VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS value was previously
+# 256. It is cut in half with the inclusion of the V1 tests to keep CI
+# test times unchanged.
+VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
+
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test.
+    # This can be promoted up to conftest.py to run for every
+    # test in a package.
+
+    # Reload punica_gpu, as the kernels used are tied to the engine type.
+    from vllm.lora.punica_wrapper import punica_gpu
+    importlib.reload(punica_gpu)
+
+    # Release any memory we might be holding on to; CI runs OOM otherwise.
+    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
+                                                _LORA_B_PTR_DICT)
+    _LORA_B_PTR_DICT.clear()
+    _LORA_A_PTR_DICT.clear()
+
+    yield
+
+
 def get_random_id_to_index(num_loras: int,
                            num_slots: int,
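For context, the autouse fixture above leans on run_with_both_engines_lora to run every test under both engines. A minimal sketch of that pattern, assuming the real fixture (defined in the test package's conftest) selects the engine through the VLLM_USE_V1 environment variable; the parametrization below is an illustrative stand-in, not the actual conftest code:

import pytest

@pytest.fixture(params=[False, True], ids=["v0-engine", "v1-engine"])
def run_with_both_engines_lora(request, monkeypatch):
    # Hypothetical stand-in: each dependent test runs twice, once per engine,
    # with the engine chosen via an environment variable.
    monkeypatch.setenv("VLLM_USE_V1", "1" if request.param else "0")
    yield request.param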
@@ -226,7 +257,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:

     torch.set_default_device(device)
     max_loras = 8
-    punica_wrapper = get_punica_wrapper(8192, 256, device)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
     assert check_punica_wrapper(punica_wrapper)
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
@@ -241,7 +272,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:

         return embedding, lora_embedding

-    for i in range(10):
+    for i in range(NUM_RANDOM_SEEDS):
         set_random_seed(i)

         id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -329,7 +360,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,

     torch.set_default_device(device)
     max_loras = 8
-    punica_wrapper = get_punica_wrapper(8192, 256, device)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
     assert check_punica_wrapper(punica_wrapper)
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
@@ -353,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,

         return expanded_embedding, lora_embedding

-    for i in range(10):
+    for i in range(NUM_RANDOM_SEEDS):
         set_random_seed(i)

         id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -468,7 +499,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,

     torch.set_default_device(device)
     max_loras = 8
-    punica_wrapper = get_punica_wrapper(8192, 256, device)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
     assert check_punica_wrapper(punica_wrapper)
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
@@ -490,7 +521,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,

         return linear, logits_processor, lora_logits_processor

-    for i in range(10):
+    for i in range(NUM_RANDOM_SEEDS):
         set_random_seed(i)

         id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -600,10 +631,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
     if current_platform.is_cuda_alike():
         torch.cuda.set_device(device)

-    torch.set_default_device(device)
-    punica_wrapper = get_punica_wrapper(8192, 256, device)
-    assert check_punica_wrapper(punica_wrapper)
     max_loras = 8
+    torch.set_default_device(device)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
                              lora_dtype=torch.float16,
@@ -627,7 +658,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

-    for i in range(10):
+    for i in range(NUM_RANDOM_SEEDS):
         set_random_seed(i)

         id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -716,10 +747,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
     if current_platform.is_cuda_alike():
         torch.cuda.set_device(device)

-    torch.set_default_device(device)
-    punica_wrapper = get_punica_wrapper(8192, 256, device)
-    assert check_punica_wrapper(punica_wrapper)
     max_loras = 8
+    torch.set_default_device(device)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
                              fully_sharded_loras=fully_shard,
@@ -753,7 +784,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

-    for i in range(10):
+    for i in range(NUM_RANDOM_SEEDS):
         set_random_seed(i)

         id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -842,10 +873,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     if current_platform.is_cuda_alike():
         torch.cuda.set_device(device)

-    torch.set_default_device(device)
-    punica_wrapper = get_punica_wrapper(8192, 256, device)
-    assert check_punica_wrapper(punica_wrapper)
     max_loras = 8
+    torch.set_default_device(device)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
                              fully_sharded_loras=fully_shard,
@@ -900,7 +931,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

-    for i in range(10):
+    for i in range(NUM_RANDOM_SEEDS):
         set_random_seed(i)

         id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -1002,12 +1033,12 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                        is_neox_style, rotary_dim, head_size,
                                        seq_len) -> None:
     dtype = torch.float16
+    max_loras = 8
     seed = 0
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
-    punica_wrapper = get_punica_wrapper(8192, 256, device)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
     assert check_punica_wrapper(punica_wrapper)
-    max_loras = 8
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
                              long_lora_scaling_factors=scaling_factors,
@@ -1083,7 +1114,8 @@


 @pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
-@pytest.mark.parametrize("seed", list(range(256)))
+@pytest.mark.parametrize(
+    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
 def test_vocab_parallel_embedding_indices(tp_size, seed):
     random.seed(seed)
     vocab_size = random.randint(4000, 64000)
@@ -5,10 +5,12 @@ import pytest
 import torch

 import vllm.lora.ops.triton_ops  # noqa: F401
+import vllm.lora.ops.triton_ops.v1  # noqa: F401
 from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
                                      bgmv_shrink, sgmv_expand,
                                      sgmv_expand_slice, sgmv_shrink)
 from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
+from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
 from vllm.platforms import current_platform

 from .utils import (PunicaTensors, assert_close, generate_data,
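The two triton_ops imports are side-effect imports (hence the noqa: F401): importing the modules registers their kernels under the torch.ops.vllm namespace so that the tests can call them as torch.ops.vllm.sgmv_shrink, torch.ops.vllm.v1_shrink, and so on. A minimal sketch of that registration mechanism, using the generic torch.library API rather than vLLM's own registration helper; the vllm_demo namespace and demo_shrink op are made up for illustration:

import torch

@torch.library.custom_op("vllm_demo::demo_shrink", mutates_args=("out", ))
def demo_shrink(x: torch.Tensor, out: torch.Tensor) -> None:
    # Toy body: a real shrink kernel would apply the LoRA-A projection.
    out.copy_(x)

# After registration the op is reachable through the torch.ops namespace:
# torch.ops.vllm_demo.demo_shrink(x, out)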
@@ -91,12 +93,12 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
 _dict_lock = Lock()


-def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
-                      hidden_size: int, nslices: int, dtype: torch.dtype,
-                      device: str, seq_length: int, scaling: float):
+def check_shrink_kernels(batches: int, num_loras: int, rank: int,
+                         hidden_size: int, nslices: int, dtype: torch.dtype,
+                         device: str, seq_length: int, scaling: float):
     """
-    Compare outputs of vllm.sgmv_shrink kernel against a reference
-    implementation.
+    Compare outputs of the vllm.sgmv_shrink and vllm.v1_shrink kernels
+    against a reference implementation.
     """
     data: PunicaTensors = generate_data_for_nslices(
         batches,
@@ -111,44 +113,63 @@ def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
     )
     max_seq_length, token_nums = data.meta()

+    # Setup metadata information for the SGMV and reference kernels.
+    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
+                      data.prompt_lora_mapping, batches, max_seq_length,
+                      token_nums)
+
+    # Setup metadata information for the V1 kernel.
+    v1_meta = V1KernelMeta.make(max_loras=num_loras,
+                                max_num_tokens=token_nums,
+                                device='cuda')
+    v1_meta.prepare_tensors(data.token_lora_mapping)
+
+    ref_out_tensor = data.ref_out_tensor
+    sgmv_out_tensor = data.our_out_tensor
+    v1_out_tensor = data.our_out_tensor.clone()
+
     # Prevent reuse of stale cached pointers.
     with _dict_lock:
+        # SGMV shrink kernel
         _LORA_A_PTR_DICT.clear()
         torch.ops.vllm.sgmv_shrink(
             data.inputs_tensor,
             data.lora_weights,
-            data.our_out_tensor,
-            data.b_seq_start_loc,
-            data.seq_len_tensor,
-            data.prompt_lora_mapping,
-            batches,
-            max_seq_length,
-            token_nums,
+            sgmv_out_tensor,
+            *sgmv_meta_args,
             scaling,
         )

-    sgmv_shrink_for_nslices(
-        nslices,
-        data.inputs_tensor,
-        data.lora_weights,
-        data.ref_out_tensor,
-        data.b_seq_start_loc,
-        data.seq_len_tensor,
-        data.prompt_lora_mapping,
-        batches,
-        max_seq_length,
-        token_nums,
-        scaling,
-    )
-    assert_close(data.our_out_tensor, data.ref_out_tensor)
+        # V1 shrink kernel
+        _LORA_A_PTR_DICT.clear()
+        torch.ops.vllm.v1_shrink(
+            data.inputs_tensor,
+            data.lora_weights,
+            v1_out_tensor,
+            *v1_meta.meta_args(token_nums=token_nums),
+            scaling,
+        )
+
+    # Reference
+    sgmv_shrink_for_nslices(
+        nslices,
+        data.inputs_tensor,
+        data.lora_weights,
+        ref_out_tensor,
+        *sgmv_meta_args,
+        scaling,
+    )
+
+    assert_close(sgmv_out_tensor, ref_out_tensor)
+    assert_close(v1_out_tensor, ref_out_tensor)


-def check_sgmv_expand(batches: int, num_loras: int, rank: int,
-                      hidden_size: int, nslices: int, dtype: torch.dtype,
-                      device: str, seq_length: int, add_inputs: bool):
+def check_expand_kernels(batches: int, num_loras: int, rank: int,
+                         hidden_size: int, nslices: int, dtype: torch.dtype,
+                         device: str, seq_length: int, add_inputs: bool):
     """
-    Compare outputs of vllm.sgmv_expand kernel against a reference
-    implementation.
+    Compare outputs of the vllm.sgmv_expand and vllm.v1_expand kernels
+    against a reference implementation.
     """
     data: PunicaTensors = generate_data_for_nslices(
         batches,
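For intuition about what both shrink kernels must reproduce: shrink projects each token's activations down through the LoRA-A matrix of whichever adapter that token is mapped to, scaled by the scaling factor. A torch-only sketch of that semantic with illustrative names and shapes; the actual reference used by the test is sgmv_shrink_for_nslices from the test utilities, which also handles multiple slices:

import torch

def shrink_reference(x, lora_a, token_lora_mapping, scaling):
    # x: [num_tokens, hidden], lora_a: [max_loras, rank, hidden]
    # token_lora_mapping: [num_tokens], the LoRA index serving each token.
    num_tokens = x.shape[0]
    rank = lora_a.shape[1]
    out = torch.zeros(num_tokens, rank, dtype=torch.float32, device=x.device)
    for lora_idx in token_lora_mapping.unique():
        mask = token_lora_mapping == lora_idx
        out[mask] = scaling * (x[mask].float() @ lora_a[lora_idx].float().t())
    return out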
@@ -164,36 +185,54 @@ def check_sgmv_expand(batches: int, num_loras: int, rank: int,

     max_seq_length, token_nums = data.meta()

+    # Setup metadata information for the SGMV and reference kernels.
+    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
+                      data.prompt_lora_mapping, batches, max_seq_length,
+                      token_nums)
+
+    # Setup metadata information for the V1 kernel.
+    v1_meta = V1KernelMeta.make(max_loras=num_loras,
+                                max_num_tokens=token_nums,
+                                device='cuda')
+    v1_meta.prepare_tensors(data.token_lora_mapping)
+
+    # Setup output tensors
+    ref_out_tensor = data.ref_out_tensor
+    sgmv_out_tensor = data.our_out_tensor
+    v1_out_tensor = data.our_out_tensor.clone()
+
     with _dict_lock:
+        # SGMV expand kernel
         _LORA_B_PTR_DICT.clear()
         torch.ops.vllm.sgmv_expand(
             data.inputs_tensor,
             data.lora_weights,
-            data.our_out_tensor,
-            data.b_seq_start_loc,
-            data.seq_len_tensor,
-            data.prompt_lora_mapping,
-            batches,
-            max_seq_length,
-            token_nums,
+            sgmv_out_tensor,
+            *sgmv_meta_args,
             offset_start=0,
             add_inputs=add_inputs,
         )

-    sgmv_expand_for_nslices(nslices,
-                            hidden_size,
-                            data.inputs_tensor,
-                            data.lora_weights,
-                            data.ref_out_tensor,
-                            data.b_seq_start_loc,
-                            data.seq_len_tensor,
-                            data.prompt_lora_mapping,
-                            batches,
-                            max_seq_length,
-                            token_nums,
-                            add_inputs=add_inputs)
+        # V1 expand kernel
+        _LORA_B_PTR_DICT.clear()
+        torch.ops.vllm.v1_expand(data.inputs_tensor,
+                                 data.lora_weights,
+                                 v1_out_tensor,
+                                 *v1_meta.meta_args(token_nums=token_nums),
+                                 offset_start=0,
+                                 add_inputs=add_inputs)
+
+    # Reference
+    sgmv_expand_for_nslices(nslices,
+                            hidden_size,
+                            data.inputs_tensor,
+                            data.lora_weights,
+                            ref_out_tensor,
+                            *sgmv_meta_args,
+                            add_inputs=add_inputs)

-    assert_close(data.our_out_tensor, data.ref_out_tensor)
+    assert_close(sgmv_out_tensor, ref_out_tensor)
+    assert_close(v1_out_tensor, ref_out_tensor)


 def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
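The expand direction is the mirror image: the low-rank intermediate is projected back up through LoRA-B, and add_inputs decides whether the result accumulates into the existing output or overwrites it. A companion torch-only sketch under the same illustrative naming; the tested reference is sgmv_expand_for_nslices:

import torch

def expand_reference(y, lora_b, token_lora_mapping, out, add_inputs):
    # y: [num_tokens, rank], lora_b: [max_loras, hidden, rank]
    # out: [num_tokens, hidden], updated in place.
    for lora_idx in token_lora_mapping.unique():
        mask = token_lora_mapping == lora_idx
        update = (y[mask].float() @ lora_b[lora_idx].float().t()).to(out.dtype)
        if add_inputs:
            out[mask] += update
        else:
            out[mask] = update
    return out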
@@ -439,7 +478,7 @@ SEED = [0]
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("seed", SEED)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
-def test_punica_sgmv(
+def test_kernels(
     batches: int,
     num_loras: int,
     rank: int,
@@ -450,29 +489,32 @@ def test_punica_sgmv(
     seed: int,
     op_type: str,
 ):
+    """
+    Tests SGMV and V1 kernels.
+    """
     torch.set_default_device(device)
     current_platform.seed_everything(seed)

     if op_type == "shrink":
-        check_sgmv_shrink(batches=batches,
-                          num_loras=num_loras,
-                          rank=rank,
-                          hidden_size=hidden_size,
-                          nslices=nslices,
-                          dtype=dtype,
-                          device=device,
-                          seq_length=128,
-                          scaling=0.5)
+        check_shrink_kernels(batches=batches,
+                             num_loras=num_loras,
+                             rank=rank,
+                             hidden_size=hidden_size,
+                             nslices=nslices,
+                             dtype=dtype,
+                             device=device,
+                             seq_length=128,
+                             scaling=0.5)
     else:
-        check_sgmv_expand(batches=batches,
-                          num_loras=num_loras,
-                          rank=rank,
-                          hidden_size=hidden_size,
-                          nslices=nslices,
-                          dtype=dtype,
-                          device=device,
-                          seq_length=128,
-                          add_inputs=True)
+        check_expand_kernels(batches=batches,
+                             num_loras=num_loras,
+                             rank=rank,
+                             hidden_size=hidden_size,
+                             nslices=nslices,
+                             dtype=dtype,
+                             device=device,
+                             seq_length=128,
+                             add_inputs=True)


 @pytest.mark.parametrize("batches", hs_test_params['batches'])
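Note the shape of the V1 metadata flow these tests exercise: V1KernelMeta is sized once for the worst case (max_loras, max_num_tokens), refreshed per batch from the token-to-LoRA mapping, and then splatted into each kernel call. The snippet below only rearranges calls that appear in this diff; anything about the internals of meta_args is an assumption:

# Sized once, up front.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
                            max_num_tokens=token_nums,
                            device='cuda')
# Refreshed per batch: which LoRA serves each token.
v1_meta.prepare_tensors(data.token_lora_mapping)
# Splatted into the kernel alongside inputs, weights and the output tensor.
torch.ops.vllm.v1_shrink(data.inputs_tensor, data.lora_weights, out_tensor,
                         *v1_meta.meta_args(token_nums=token_nums), scaling)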
@@ -484,7 +526,7 @@ def test_punica_sgmv(
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("seed", SEED)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
-def test_punica_sgmv_hidden_size(
+def test_kernels_hidden_size(
     batches: int,
     num_loras: int,
     rank: int,
@@ -495,29 +537,32 @@ def test_punica_sgmv_hidden_size(
     seed: int,
     op_type: str,
 ):
+    """
+    Tests SGMV and V1 kernels.
+    """
     torch.set_default_device(device)
     current_platform.seed_everything(seed)

     if op_type == "shrink":
-        check_sgmv_shrink(batches=batches,
-                          num_loras=num_loras,
-                          rank=rank,
-                          hidden_size=hidden_size,
-                          nslices=nslices,
-                          dtype=dtype,
-                          device=device,
-                          seq_length=128,
-                          scaling=0.5)
+        check_shrink_kernels(batches=batches,
+                             num_loras=num_loras,
+                             rank=rank,
+                             hidden_size=hidden_size,
+                             nslices=nslices,
+                             dtype=dtype,
+                             device=device,
+                             seq_length=128,
+                             scaling=0.5)
     else:
-        check_sgmv_expand(batches=batches,
-                          num_loras=num_loras,
-                          rank=rank,
-                          hidden_size=hidden_size,
-                          nslices=nslices,
-                          dtype=dtype,
-                          device=device,
-                          seq_length=128,
-                          add_inputs=True)
+        check_expand_kernels(batches=batches,
+                             num_loras=num_loras,
+                             rank=rank,
+                             hidden_size=hidden_size,
+                             nslices=nslices,
+                             dtype=dtype,
+                             device=device,
+                             seq_length=128,
+                             add_inputs=True)


 @pytest.mark.parametrize("batches", test_params['batches'])