Compare commits

..

4 Commits

Author SHA1 Message Date
031c8b32a4 Add time comment
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-03-17 13:50:44 +00:00
ac08d45200 Merge branch 'main' into mamba_tests 2025-03-17 13:49:56 +00:00
a5d29e9ee1 undo massive formatting change
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-03-15 17:31:21 +00:00
696245c2fc Add SSM and Hybrid Models Test
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-03-15 17:26:01 +00:00
54 changed files with 2276 additions and 1501 deletions

View File

@ -200,7 +200,6 @@ steps:
- pytest -v -s v1/core
- pytest -v -s v1/entrypoints
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
- pytest -v -s v1/sample
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
@ -454,6 +453,15 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model'
- label: SSM and Hybrid Models Test # 12min
source_file_dependencies:
- vllm/
- tests/models/decoder_only/language/test_hybrid.py
- tests/models/decoder_only/language/test_mamba.py
commands:
- pytest -v -s models/decoder_only/language/test_hybrid.py
- pytest -v -s models/decoder_only/language/test_mamba.py
# This test is used only in PR development phase to test individual models and should never run on main
- label: Custom Models Test
optional: true

View File

@ -17,8 +17,13 @@ from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@ -162,25 +167,69 @@ class OpType(Enum):
"""
LoRA Ops to benchmark and its properties.
"""
LORA_SHRINK = auto()
LORA_EXPAND = auto()
SGMV_SHRINK = auto()
BGMV_SHRINK = auto()
SGMV_EXPAND = auto()
BGMV_EXPAND = auto()
BGMV_EXPAND_SLICE = auto()
V1_SHRINK = auto()
V1_EXPAND = auto()
@staticmethod
def from_str(s: str) -> "OpType":
if s.lower() == "lora_shrink":
return OpType.LORA_SHRINK
if s.lower() == "lora_expand":
return OpType.LORA_EXPAND
if s.lower() == 'sgmv_shrink':
return OpType.SGMV_SHRINK
if s.lower() == 'sgmv_expand':
return OpType.SGMV_EXPAND
if s.lower() == 'bgmv_shrink':
return OpType.BGMV_SHRINK
if s.lower() == 'bgmv_expand':
return OpType.BGMV_EXPAND
if s.lower() == "bgmv_expand_slice":
return OpType.BGMV_EXPAND_SLICE
if s.lower() == "v1_shrink":
return OpType.V1_SHRINK
if s.lower() == "v1_expand":
return OpType.V1_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType")
def is_shrink_fn(self) -> bool:
return self in [OpType.LORA_SHRINK]
return self in [
OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
]
def is_expand_fn(self) -> bool:
return self in [OpType.LORA_EXPAND]
return self in [
OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
]
def is_prefill_op(self) -> bool:
return self in [
OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
OpType.V1_EXPAND
]
def is_decode_op(self) -> bool:
return self in [
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
OpType.V1_SHRINK, OpType.V1_EXPAND
]
def is_expand_slice_fn(self) -> bool:
return self in [OpType.BGMV_EXPAND_SLICE]
def num_slices(self) -> list[int]:
return [1, 2, 3]
if self in [
OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
OpType.V1_EXPAND
]:
# SGMV kernels and v1 kernels supports slices
return [1, 2, 3]
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
return [1]
if self in [OpType.BGMV_EXPAND_SLICE]:
return [2, 3]
raise ValueError(f"Unrecognized OpType {self}")
def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int) -> tuple[int, int, int]:
@ -190,7 +239,7 @@ class OpType(Enum):
k = hidden_size
n = lora_rank
else:
assert self.is_expand_fn()
assert self.is_expand_fn() or self.is_expand_slice_fn()
m = num_tokens
k = lora_rank
n = hidden_size
@ -205,7 +254,7 @@ class OpType(Enum):
if self.is_shrink_fn():
return op_dtype, op_dtype, torch.float32
else:
assert self.is_expand_fn()
assert self.is_expand_fn() or self.is_expand_slice_fn()
return torch.float32, op_dtype, op_dtype
def matmul_shapes(
@ -219,19 +268,43 @@ class OpType(Enum):
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
b_shape = (num_loras, n, k) # col-major
if self in [OpType.LORA_SHRINK]:
# LoRA shrink kernels support num_slices inherently in the kernel.
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
# SGMV shrink and V1 shrink kernels support num_slices inherently
# in the kernel.
return ((m, k), b_shape, (num_slices, m, n))
if self in [OpType.LORA_EXPAND]:
# LoRA expand kernels support num_slices inherently in the kernel
if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
# SGMV expand and V1 expand kernels support num_slices inherently
# in the kernel
return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self == OpType.BGMV_SHRINK:
return ((m, k), b_shape, (m, n))
if self == OpType.BGMV_EXPAND:
return ((m, k), b_shape, (m, n))
if self == OpType.BGMV_EXPAND_SLICE:
return ((num_slices, m, k), b_shape, (m, n * num_slices))
raise ValueError(f"Unrecognized op_type {self}")
def bench_fn(self) -> Callable:
if self == OpType.LORA_SHRINK:
return lora_shrink
if self == OpType.LORA_EXPAND:
return lora_expand
def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
for x in kwargs_list:
bgmv_expand_slice(**x)
if self == OpType.SGMV_SHRINK:
return sgmv_shrink
if self == OpType.SGMV_EXPAND:
return sgmv_expand
if self == OpType.BGMV_SHRINK:
return bgmv_shrink
if self == OpType.BGMV_EXPAND:
return bgmv_expand
if self == OpType.BGMV_EXPAND_SLICE:
return emulate_bgmv_expand_slice
if self == OpType.V1_SHRINK:
return v1_shrink
if self == OpType.V1_EXPAND:
return v1_expand
raise ValueError(f"Unrecognized optype {self}")
@ -245,13 +318,34 @@ class OpType(Enum):
"""
w_dtype = lora_weights[0].dtype
num_slices = len(lora_weights)
if self in [OpType.LORA_SHRINK]:
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
for slice_idx in range(num_slices):
ref_group_gemm(ref_out=output[slice_idx, :],
input=input,
lora_weights=lora_weights[slice_idx],
**kwargs)
elif self in [OpType.LORA_EXPAND]:
elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
ref_group_gemm(
ref_out=output[:, slice_offset:slice_offset + hidden_size],
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
elif self == OpType.BGMV_SHRINK:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input,
lora_weights=lora_weights[0],
**kwargs)
elif self == OpType.BGMV_EXPAND:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input.clone().to(dtype=w_dtype),
lora_weights=lora_weights[0],
**kwargs)
elif self == OpType.BGMV_EXPAND_SLICE:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
@ -317,11 +411,13 @@ class BenchmarkTensors:
input: torch.Tensor
lora_weights_lst: list[torch.Tensor]
output: torch.Tensor
# LoRA kernel metadata
lora_kernel_meta: LoRAKernelMeta
# Metadata tensors used in testing correctness
# metadata tensors
seq_lens: torch.Tensor
seq_start_loc: torch.Tensor
prompt_lora_mapping: torch.Tensor
token_lora_mapping: torch.Tensor
# v1 kernel metadata
v1_kernel_meta: Optional[V1KernelMeta] = None
def io_types(self) -> str:
return (f"{dtype_to_str(self.input.dtype)}x"
@ -348,29 +444,35 @@ class BenchmarkTensors:
assert ctx.num_active_loras <= ctx.num_loras
total_tokens = ctx.batch_size * ctx.seq_length
# Make metadata tensors involved in correctness testing.
# Prepare seq lens tensor
seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
(ctx.batch_size, ))
# Prepare seq_start_loc tensor
seq_start_loc_tensor = torch.cumsum(torch.tensor(
[0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
dim=0)
assert total_tokens == seq_len_tensor.sum()
# Prepare prompt lora indices tensor
prompt_lora_indices_tensor = make_prompt_lora_mapping(
ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
# Make LoRAKernelMeta
# Prepare token lora indices tensor
token_lora_indices_tensor = make_token_lora_mapping(
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
seq_len_tensor, "cpu")
lora_kernel_meta = LoRAKernelMeta.make(
max_loras=ctx.num_loras,
max_num_tokens=token_lora_indices_tensor.size(0),
device="cpu")
lora_kernel_meta.prepare_tensors(
token_lora_mapping=token_lora_indices_tensor)
v1_kernel_meta = None
if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
v1_kernel_meta = V1KernelMeta.make(
max_loras=ctx.num_loras,
max_num_tokens=token_lora_indices_tensor.size(0),
device="cpu")
v1_kernel_meta.prepare_tensors(
token_lora_mapping=token_lora_indices_tensor)
return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
lora_kernel_meta, seq_len_tensor,
prompt_lora_indices_tensor)
seq_len_tensor, seq_start_loc_tensor,
prompt_lora_indices_tensor,
token_lora_indices_tensor, v1_kernel_meta)
def sanity_check(self) -> None:
"""
@ -380,9 +482,9 @@ class BenchmarkTensors:
# check metadata tensors
assert torch.sum(self.seq_lens) == num_tokens
num_seqs = self.seq_lens.shape[0]
#assert self.seq_start_loc.shape[0] == num_seqs
assert self.seq_start_loc.shape[0] == num_seqs
assert self.prompt_lora_mapping.shape[0] == num_seqs
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
assert self.token_lora_mapping.shape[0] == num_tokens
def to_device(self, device: str):
"""
@ -397,27 +499,220 @@ class BenchmarkTensors:
self.input = to_device(self.input)
self.output = to_device(self.output)
self.seq_lens = to_device(self.seq_lens)
self.seq_start_loc = to_device(self.seq_start_loc)
self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
self.token_lora_mapping = to_device(self.token_lora_mapping)
for i in range(len(self.lora_weights_lst)):
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
# LoRA meta
for field_name in LoRAKernelMeta.__dataclass_fields__:
field = getattr(self.lora_kernel_meta, field_name)
assert isinstance(field, torch.Tensor)
setattr(self.lora_kernel_meta, field_name, to_device(field))
# v1 meta
if self.v1_kernel_meta:
for field_name in V1KernelMeta.__dataclass_fields__:
field = getattr(self.v1_kernel_meta, field_name)
assert isinstance(field, torch.Tensor)
setattr(self.v1_kernel_meta, field_name, to_device(field))
def metadata(self) -> tuple[int, int, int]:
"""
Return num_seqs, num_tokens and max_seq_len
"""
num_seqs = self.seq_lens.shape[0]
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
num_tokens = self.token_lora_mapping.shape[0]
max_seq_len = torch.max(self.seq_lens).item()
num_slices = len(self.lora_weights_lst)
return num_seqs, num_tokens, max_seq_len, num_slices
def as_lora_shrink_kwargs(self) -> dict[str, Any]:
def convert_to_sgmv_benchmark_tensors(self):
"""
For sgmv punica kernels, when consecutive sequences have the
same LoRA ID, we just merge them together.
This happens in punica.py::compute_metadata
"""
# Collapse seq_lens and seq_start_loc
_, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
return_counts=True)
cum_result = torch.cumsum(seq_lens, dim=0)
seq_start_loc = torch.zeros_like(seq_lens)
seq_start_loc[1:].copy_(cum_result[:-1])
# Collapse prompt mapping
prompt_lora_mapping = torch.unique_consecutive(
self.prompt_lora_mapping)
assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"
self.prompt_lora_mapping = prompt_lora_mapping.to(
dtype=self.prompt_lora_mapping.dtype)
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors()
self.sanity_check()
self.to_device(self.input.device)
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert len(lw_shape) == 3
assert lw_shape[2] == hidden_size
lora_rank = lw_shape[1]
# Expected output shape [num_slices, num_tokens, lora_rank]
assert len(o_shape) == 3
assert o_shape == (num_slices, num_tokens, lora_rank)
return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output,
'b_seq_start_loc': self.seq_start_loc,
'seq_len_tensor': self.seq_lens,
'lora_indices_tensor': self.prompt_lora_mapping,
'batches': num_seqs,
'max_seq_length': max_seq_len,
'token_nums': num_tokens,
'scaling': 1.0,
}
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors()
self.sanity_check()
self.to_device(self.input.device)
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[2]
# Expected lora weight shape : [num_lora, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape : [num_tokens, hidden_size * num_slices]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size * num_slices)
return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output,
'b_seq_start_loc': self.seq_start_loc,
'seq_len_tensor': self.seq_lens,
'lora_indices_tensor': self.prompt_lora_mapping,
'batches': num_seqs,
'max_seq_length': max_seq_len,
'token_nums': num_tokens,
'offset_start': 0,
'add_inputs': add_inputs,
}
def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
assert len(self.lora_weights_lst) == 1
self.to_device(self.input.device)
_, num_tokens, _, _ = self.metadata()
# Sanity check shapes
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert len(lw_shape) == 3
assert lw_shape[2] == hidden_size
lora_rank = lw_shape[1]
# Expected output shape [num_tokens, lora_rank]
assert len(o_shape) == 2
assert o_shape == (num_tokens, lora_rank)
return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst[0],
'output_tensor': self.output,
'lora_indices_tensor': self.token_lora_mapping,
'scaling': 1.0
}
def as_bgmv_expand_kwargs(self, add_inputs: bool):
assert len(self.lora_weights_lst) == 1
self.to_device(self.input.device)
_, num_tokens, _, _ = self.metadata()
# Sanity check shapes
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, lora_rank]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
lora_rank = i_shape[1]
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape [num_tokens, hidden_size]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size)
return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst[0],
'output_tensor': self.output,
'lora_indices_tensor': self.token_lora_mapping,
'add_inputs': add_inputs
}
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
_, num_tokens, _, num_slices = self.metadata()
# Sanity check shapes
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[2]
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape [num_tokens, hidden_size * num_slices]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size * num_slices)
self.to_device(self.input.device)
kwargs_list = []
for i in range(num_slices):
kwargs_list.append({
'inputs': self.input[i],
'lora_b_weights': self.lora_weights_lst[i],
'output_tensor': self.output,
'lora_indices_tensor': self.token_lora_mapping,
'slice_offset': i * hidden_size,
'slice_size': hidden_size,
'add_inputs': add_inputs,
})
return {'kwargs_list': kwargs_list}
def as_v1_shrink_kwargs(self) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)
@ -442,16 +737,17 @@ class BenchmarkTensors:
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
'lora_ids': self.lora_kernel_meta.active_lora_ids,
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'scaling': 1.0,
}
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)
@ -477,12 +773,12 @@ class BenchmarkTensors:
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
'lora_ids': self.lora_kernel_meta.active_lora_ids,
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'offset_start': 0,
'add_inputs': add_inputs,
}
@ -495,10 +791,20 @@ class BenchmarkTensors:
else:
assert add_inputs is not None
if op_type == OpType.LORA_SHRINK:
return self.as_lora_shrink_kwargs()
if op_type == OpType.LORA_EXPAND:
return self.as_lora_expand_kwargs(add_inputs)
if op_type == OpType.SGMV_SHRINK:
return self.as_sgmv_shrink_kwargs()
if op_type == OpType.SGMV_EXPAND:
return self.as_sgmv_expand_kwargs(add_inputs)
if op_type == OpType.BGMV_SHRINK:
return self.as_bgmv_shrink_kwargs()
if op_type == OpType.BGMV_EXPAND:
return self.as_bgmv_expand_kwargs(add_inputs)
if op_type == OpType.BGMV_EXPAND_SLICE:
return self.as_bgmv_expand_slice_kwargs(add_inputs)
if op_type == OpType.V1_SHRINK:
return self.as_v1_shrink_kwargs()
if op_type == OpType.V1_EXPAND:
return self.as_v1_expand_kwargs(add_inputs)
raise ValueError(f"Unrecognized optype {self}")
def test_correctness(self, op_type: OpType,
@ -687,6 +993,10 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
for bench_ctx in bench_ctxs:
for seq_len in args.seq_lengths:
bench_ops: list[OpType] = args.op_types
if seq_len > 1:
# bench only prefill ops
bench_ops = [op for op in args.op_types if op.is_prefill_op()]
seq_len_timers = []
for bench_op in bench_ops:
for num_slices in bench_op.num_slices():
@ -896,13 +1206,13 @@ Benchmark LoRA kernels:
{use_cuda_graph_recommendation()}
list_bench example:
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
model_bench example:
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
range_bench example:
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)

View File

@ -477,11 +477,6 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
*
*
:::
:::{note}
@ -768,7 +763,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
* ✅︎
* ✅︎
*
* ⚠️
- * `GLM4VForCausalLM`<sup>^</sup>
* GLM-4V
* T + I
@ -953,11 +948,8 @@ V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs
- Will be updated in the future to support the correct behavior
- Does not support `"do_pan_and_scan": True`
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
For these reasons, `Gemma3ForConditionalGeneration` is supported only on V0 at the moment.
:::
:::{note}

View File

@ -93,6 +93,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
limit_mm_per_prompt={"audio": audio_count},
)

View File

@ -682,6 +682,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
)
return ModelRequestData(

View File

@ -342,6 +342,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
)
placeholders = "".join(f"<|image_{i}|>"

View File

@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
outlines == 0.1.11
lark == 1.2.2
xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
xgrammar == 0.1.15; platform_machine == "x86_64" or platform_machine == "aarch64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs

View File

@ -235,7 +235,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.5.4
mistral-common==1.5.1
# via -r requirements/test.in
more-itertools==10.5.0
# via lm-eval

View File

@ -4,13 +4,18 @@ from threading import Lock
import pytest
import torch
import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
import vllm.lora.ops.triton_ops # noqa: F401
import vllm.lora.ops.triton_ops.v1 # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
from vllm.platforms import current_platform
from .utils import PunicaTensors, assert_close, generate_data_for_nslices
from .utils import (PunicaTensors, assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices)
# Utility shrink and expand operations used as reference implementations.
@ -21,10 +26,10 @@ def sgmv_shrink_for_nslices(
prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
num_tokens: int, scaling: float):
"""
Wrapper around torch_ops.sgmv_shrink that handles any nslices.
Wrapper around sgmv_shrink that handles any nslices.
"""
for index in range(nslices):
torch_ops.sgmv_shrink(
sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
out_tensor[index],
@ -48,11 +53,11 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
max_seq_length: int, num_tokens: int,
add_inputs: bool) -> None:
"""
Wrapper around torch_ops.sgmv_expand that handles any nslices.
Wrapper around sgmv_expand that handles any nslices.
"""
if nslices == 1:
# Verify the torch's sgmv_expand op
torch_ops.sgmv_expand(
sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
out_tensor,
@ -68,7 +73,7 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
torch_ops.sgmv_expand_slice(
sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
out_tensor,
@ -88,13 +93,12 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
_dict_lock = Lock()
def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str, seq_length: int,
scaling: float):
def check_shrink_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
"""
Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
kernels.
Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernel against a
reference implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@ -114,24 +118,35 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
lora_meta.prepare_tensors(data.token_lora_mapping)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
ref_out_tensor = data.ref_out_tensor
out_tensor = data.our_out_tensor.clone()
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
# Preventing cache error pointer.
with _dict_lock:
# lora_shrink kernel
# SGMV shrink kernel
_LORA_A_PTR_DICT.clear()
triton_ops.lora_shrink(
torch.ops.vllm.sgmv_shrink(
data.inputs_tensor,
data.lora_weights,
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
sgmv_out_tensor,
*sgmv_meta_args,
scaling,
)
# V1 shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.v1_shrink(
data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
scaling,
)
@ -145,16 +160,16 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
scaling,
)
assert_close(out_tensor, ref_out_tensor)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str, seq_length: int,
add_inputs: bool):
def check_expand_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
"""
Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
kernels.
Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a
reference implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@ -175,25 +190,37 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
lora_meta.prepare_tensors(data.token_lora_mapping)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
# Setup output tensors
ref_out_tensor = data.ref_out_tensor
out_tensor = data.our_out_tensor.clone()
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
with _dict_lock:
# lora_expand kernel
# SGMV expand kernel
_LORA_B_PTR_DICT.clear()
triton_ops.lora_expand(data.inputs_tensor,
data.lora_weights,
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
offset_start=0,
add_inputs=add_inputs)
torch.ops.vllm.sgmv_expand(
data.inputs_tensor,
data.lora_weights,
sgmv_out_tensor,
*sgmv_meta_args,
offset_start=0,
add_inputs=add_inputs,
)
# V1 expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.v1_expand(data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
offset_start=0,
add_inputs=add_inputs)
# Reference
sgmv_expand_for_nslices(nslices,
@ -204,7 +231,124 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
*sgmv_meta_args,
add_inputs=add_inputs)
assert_close(out_tensor, ref_out_tensor)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
scaling: float):
"""
Compare vllm.bgmv_shrink against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"shrink",
device,
)
torch.ops.vllm.bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
scaling,
)
bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
scaling,
)
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
add_inputs: bool):
"""
Compare vllm.bgmv_expand against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"expand",
device,
)
torch.ops.vllm.bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, add_inputs: bool):
"""
Compare vllm.bgmv_expand_slice against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
slice_offset = 0
for index in range(nslices):
torch.ops.vllm.bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.our_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.ref_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
slice_offset += hidden_size
assert_close(data.our_out_tensor, data.ref_out_tensor)
# Tests
@ -346,31 +490,31 @@ def test_kernels(
op_type: str,
):
"""
Tests LoRA kernels.
Tests SGMV and V1 kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_lora_shrink_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_lora_expand_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@ -394,28 +538,159 @@ def test_kernels_hidden_size(
op_type: str,
):
"""
Tests SGMV and LoRA kernels.
Tests SGMV and V1 kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_lora_shrink_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_lora_expand_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv_hidden_size(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str,
seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int,
rank: int, hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str, seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import json
import pickle
import pytest
@ -209,6 +208,8 @@ def test_guided_decoding_backend_options():
def test_pickle_xgrammar_tokenizer_data():
# TODO: move to another test file for xgrammar
try:
import xgrammar as xgr
except ImportError:
@ -216,11 +217,7 @@ def test_pickle_xgrammar_tokenizer_data():
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
TokenizerData)
tokenizer_data = TokenizerData(
metadata=
'{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
encoded_vocab=['!', '"', '#', '$', '%'],
)
tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
pickled = pickle.dumps(tokenizer_data)
assert pickled is not None
@ -228,5 +225,4 @@ def test_pickle_xgrammar_tokenizer_data():
depickled: TokenizerData = pickle.loads(pickled)
assert depickled is not None
assert json.loads(
depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
assert depickled.vocab_type == xgr.VocabType.RAW

View File

@ -9,7 +9,7 @@ from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal
# This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
MODELS = ["ai21labs/Jamba-tiny-dev"]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
@ -27,19 +27,17 @@ def test_models(
) -> None:
# numeric error produces different generation
if "Bamba" in model:
if 'Bamba' in model:
example_prompts.pop(3)
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
@ -114,31 +112,26 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int) -> None:
# numeric error during prefill chunking produces different generation
# numeric error during prefill chucking produces different generation
# compared to w/o prefill chunking for those examples, removed them for now
if "Jamba" in model:
if 'Jamba' in model:
example_prompts.pop(7)
example_prompts.pop(2)
example_prompts.pop(1)
elif "Bamba" in model:
elif 'Bamba' in model:
example_prompts.pop(6)
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba
elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model,

View File

@ -100,6 +100,7 @@ def run_test(
distributed_executor_backend=distributed_executor_backend,
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI
enforce_eager=True,
) as vllm_model:

View File

@ -195,8 +195,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False,
trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
min_transformers_version="4.49"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),

View File

@ -18,6 +18,9 @@ MODELS_TO_TEST = [
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
]
# Undo after https://github.com/vllm-project/vllm/pull/14868
pytest.skip(allow_module_level=True)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",

View File

@ -821,11 +821,6 @@ class ModelConfig:
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
return self.hf_text_config.attention_head_dim
if self.is_attention_free:
return 0
@ -909,9 +904,7 @@ class ModelConfig:
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
# the layout order is: DP x PP x TP
pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
) % parallel_config.pipeline_parallel_size
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return start, end
@ -949,15 +942,6 @@ class ModelConfig:
"cannot determine the num of "
f"{block_type.value} layers")
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value
for t in layers_block_type_value[start:end])
@ -2324,7 +2308,7 @@ class LoRAConfig:
# Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications.
possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
possible_lora_extra_vocab_size = (256, 512)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
raise ValueError(
f"max_lora_rank ({self.max_lora_rank}) must be one of "

View File

@ -897,23 +897,10 @@ def initialize_model_parallel(
get_world_group().device_group)
data_parallel_size = 1
has_external_dp = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
if config.parallel_config.world_size != world_size:
# detect external data parallelism.
# dp in vllm means all dp instances need to run together.
# if the world size does not match, it means this dp is external,
# and the dp instances can run independently, e.g. in rlhf workflow
# from https://github.com/volcengine/verl .
# in that case, we treat the rest dimensions as if they are
# data parallel, and create a dummy dp group that is not used.
data_parallel_size = world_size // (pipeline_model_parallel_size *
tensor_model_parallel_size)
has_external_dp = True
else:
data_parallel_size = config.parallel_config.data_parallel_size
data_parallel_size = config.parallel_config.data_parallel_size
# the layout order is: DP x PP x TP
# to get group_ranks for each dimension, transpose that dimension to the
@ -953,12 +940,6 @@ def initialize_model_parallel(
2).reshape(-1,
data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if has_external_dp:
# create a dummy dp group that is not used actually,
# since this dp is external.
# a dummy dp group means every rank is a group itself.
# this way, no communication is needed, no memory is wasted.
group_ranks = [[x] for x in range(world_size)]
_DP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,

View File

@ -3,7 +3,6 @@
import argparse
import dataclasses
import json
import threading
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast, get_args)
@ -1192,7 +1191,7 @@ class EngineArgs:
NOTE: for autoselection of V0 vs V1 engine, we need to
create the ModelConfig first, since ModelConfig's attrs
(e.g. the model arch) are needed to make the decision.
This function set VLLM_USE_V1=X if VLLM_USE_V1 is
unspecified by the user.
@ -1577,9 +1576,8 @@ class EngineArgs:
#############################################################
# Experimental Features - allow users to opt in.
# Signal Handlers requires running in main thread.
if (threading.current_thread() != threading.main_thread()
and _warn_or_fallback("Engine in background thread")):
# MLA is is supported on V1, but off by default for now.
if model_config.use_mla and _warn_or_fallback("MLA"):
return False
# LoRA is supported on V1, but off by default for now.

View File

@ -29,8 +29,6 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.worker.model_runner_base import InputProcessingError
@ -44,12 +42,12 @@ class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to
in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The :class:`LLMEngine` generate or encode process is kicked off when a new
RPCProcessRequest is received by the input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
:class:`LLMEngine.step()`, and sends the RequestOutputs back over
@ -430,9 +428,6 @@ def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, disable_log_stats: bool,
disable_log_requests: bool, engine_alive):
try:
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
engine = MQLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,

View File

@ -82,8 +82,6 @@ from vllm.entrypoints.openai.serving_transcription import (
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
@ -223,9 +221,6 @@ async def build_async_engine_client_from_engine_args(
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
# The Process can raise an exception during startup, which may
# not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information.

View File

@ -1,11 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.lora.ops.triton_ops.lora_expand import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink import lora_shrink
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401
__all__ = [
"lora_expand",
"lora_shrink",
"LoRAKernelMeta",
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_shrink",
]

View File

@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
performance
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
current_n_c = tl.max_contiguous(current_n, BLOCK_N)
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n_c[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride,
mask=c_mask,
other=0.0)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def _bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
def bgmv_expand_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
return
try:
direct_register_custom_op(
op_name="bgmv_expand",
op_func=_bgmv_expand,
mutates_args=["output_tensor"],
fake_impl=bgmv_expand_fake,
)
bgmv_expand = torch.ops.vllm.bgmv_expand
except AttributeError:
bgmv_expand = _bgmv_expand

View File

@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_slice_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
performance
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
slice_offset * cn_stride)
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store
# operation uses the same mask, eliminating the risk of garbage
# values propagating
tiled_out = tl.load(c_ptr + current_n * cn_stride,
mask=c_mask,
other=None)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def _bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'b weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_slice_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
def bgmv_expand_slice_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> None:
return
try:
direct_register_custom_op(
op_name="bgmv_expand_slice",
op_func=_bgmv_expand_slice,
mutates_args=["output_tensor"],
fake_impl=bgmv_expand_slice_fake,
)
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice

View File

@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_shrink_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
scaling,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's
performance
"""
pid_sk = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_n = tl.arange(0, BLOCK_N)
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
a_ptr = input_ptr + cur_batch * xm_stride
b_ptr = lora_ptr + l0_stride * lora_index
accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
for k in range(0, K, BLOCK_K * SPLIT_K):
current_k = k + offset_k
current_k_c = tl.max_contiguous(current_k, BLOCK_K)
tiled_a = tl.load(
a_ptr + current_k_c,
mask=current_k < K,
other=0.0,
) # [BLOCK_K]
b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
tiled_b = tl.load(
b_ptr + offset_n[:, None] * lora_k_stride +
current_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
accumulator += tl.sum(tiled_a * tiled_b, 1)
accumulator *= scaling
offset_cn = tl.arange(0, BLOCK_N)
c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
c_mask = offset_cn < N
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def _bgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_a_weights.size(-1)
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
batches = lora_indices_tensor.size(0)
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
BLOCK_N = triton.next_power_of_2(N)
# First try to load optimal config from the file
config = get_lora_op_configs("bgmv_shrink", batches, K)
grid = lambda META: (
META["SPLIT_K"],
batches,
)
_bgmv_shrink_kernel[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_N=BLOCK_N,
**config,
)
return
def bgmv_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
return
try:
direct_register_custom_op(
op_name="bgmv_shrink",
op_func=_bgmv_shrink,
mutates_args=["output_tensor"],
fake_impl=bgmv_shrink_fake,
)
bgmv_shrink = torch.ops.vllm.bgmv_shrink
except AttributeError:
bgmv_shrink = _bgmv_shrink

View File

@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .kernel_utils import do_expand_kernel
from .utils import _get_lora_b_ptr
@triton.jit
def _sgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
slice_start_loc,
input_d0_stride,
input_d1_stride,
input_d2_stride, # 1
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
slice_id = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
# When the output dimensions of each slice are the same,cur_n=N, otherwise
# cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
# qkv linear.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M >= M:
return
if pid_n * BLOCK_N >= curr_N:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
m_offset = tl.load(b_seq_start_loc + cur_batch)
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
cta_m_offset = m_offset + (pid_m * BLOCK_M)
offset_m = tl.arange(0, BLOCK_M)
ram = cta_m_offset + tl.max_contiguous(
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
do_expand_kernel(
pid_n,
lora_index,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
curr_N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M,
BLOCK_N,
BLOCK_K,
SAME_STRIDE,
SLICE_NUM,
EVEN_K,
CAST_TYPE,
ADD_INPUTS,
)
@torch.inference_mode()
def _sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (List[torch.Tensor]): lora'b weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(1) == token_nums
assert inputs.size(0) == len(lora_b_weights)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert output_tensor.is_contiguous()
(slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
b_seq_start_loc.device)
# TODO tuning this config
K = lora_b_weights[0].shape[-1] # K= rank
BLOCK_M = 64
BLOCK_N = 128
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
batches,
len(lora_b_weights),
)
_sgmv_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
MAX_N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
len(lora_b_weights),
same_stride,
)
return
def _sgmv_expand_fake(
inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="sgmv_expand",
op_func=_sgmv_expand,
mutates_args=["output_tensor"],
fake_impl=_sgmv_expand_fake,
)
sgmv_expand = torch.ops.vllm.sgmv_expand
except AttributeError:
sgmv_expand = _sgmv_expand

View File

@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .kernel_utils import do_shrink_kernel
from .utils import _get_lora_a_ptr
@triton.jit
def _sgmv_shrink_kernel(
input_ptr,
lora_ptr, #1-3
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
scaling,
input_d0_stride,
input_d1_stride, # 1
lora_d0_stride,
lora_d1_stride,
lora_d2_stride, # 1
output_d0_stride,
output_d1_stride,
output_d2_stride, # 1
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr):
"""
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
introducing SPLIT-K can improve performance
"""
pid = tl.program_id(axis=0)
pid_mix = tl.program_id(axis=1)
cur_batch = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
if SLICE_NUM == 1:
slice_id: tl.constexpr = 0
pid_sk = tl.program_id(axis=1)
else:
pid_mix = tl.program_id(axis=1)
slice_id = pid_mix // SPLIT_K
pid_sk = pid_mix % SPLIT_K
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M >= M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
m_offset = tl.load(b_seq_start_loc + cur_batch)
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
cta_m_offset = m_offset + (pid_m * BLOCK_M)
offset_m = tl.arange(0, BLOCK_M)
ram = cta_m_offset + tl.max_contiguous(
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_index,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
cta_m_len,
ram,
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
SLICE_NUM)
@torch.inference_mode()
def _sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (List[torch.Tensor]): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device)
# TODO tuning this config
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
BLOCK_M = 32
BLOCK_N = 16
BLOCK_K = 32
SPLIT_K = 8
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K * len(lora_a_weights),
batches,
)
_sgmv_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
len(lora_a_weights),
)
return
def sgmv_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
return
try:
direct_register_custom_op(
op_name="sgmv_shrink",
op_func=_sgmv_shrink,
mutates_args=["output_tensor"],
fake_impl=sgmv_shrink_fake,
)
sgmv_shrink = torch.ops.vllm.sgmv_shrink
except AttributeError:
sgmv_shrink = _sgmv_shrink

View File

@ -1,9 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
import functools
from typing import Dict, List, Tuple
import torch
@functools.lru_cache
def _get_op_configs(op_type: str, batch: int, hidden_size: int):
# TODO: add optimal configurations
return None
def _check_divisibility(hidden_size: int):
# The bgmv_expand kernel requires that the hidden_size be divisible by
# the number below.
divisibility = [2, 4, 8, 16, 32, 64]
divisibility.sort(reverse=True)
for div in divisibility:
if hidden_size % div == 0:
return div
# hidden_size is an odd number
return 1
def _get_default_config(op_type: str, batch: int, hidden_size: int):
if op_type == "expand":
return {
"BLOCK_N": 256,
"SPLIT_N": _check_divisibility(hidden_size),
"num_warps": 8
}
else:
return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8}
def get_lora_op_configs(op_type: str, batch: int,
hidden_size: int) -> Dict[str, int]:
"""Inspired by `fused_moe_kernel`
The return value will be a dictionary mapping an irregular grid of batch
sizes and hidden_size to configurations of the bgmv-related kernel.
NOTE: It currently only supports the default configuration. We plan to
generate optimal configurations for different hardware in the future using
scripts similar to `benchmark_moe.py`.
"""
config = _get_op_configs(op_type, batch, hidden_size)
if not config:
config = _get_default_config(op_type, batch, hidden_size)
return config
_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}

View File

@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand
from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta
from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink
__all__ = [
"v1_expand",
"v1_shrink",
"V1KernelMeta",
]

View File

@ -18,7 +18,7 @@ from vllm.utils import direct_register_custom_op
@triton.jit
def _lora_expand_kernel(
def _v1_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
@ -125,7 +125,7 @@ def _lora_expand_kernel(
@torch.inference_mode()
def _lora_expand(
def _v1_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: List[
torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
@ -216,7 +216,7 @@ def _lora_expand(
MAX_LORAS,
)
_lora_expand_kernel[grid](
_v1_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
@ -254,7 +254,7 @@ def _lora_expand(
return
def _lora_expand_fake(
def _v1_expand_fake(
inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
@ -271,12 +271,12 @@ def _lora_expand_fake(
try:
direct_register_custom_op(
op_name="lora_expand",
op_func=_lora_expand,
op_name="v1_expand",
op_func=_v1_expand,
mutates_args=["output_tensor"],
fake_impl=_lora_expand_fake,
fake_impl=_v1_expand_fake,
)
lora_expand = torch.ops.vllm.lora_expand
v1_expand = torch.ops.vllm.v1_expand
except AttributeError:
lora_expand = _lora_expand
v1_expand = _v1_expand

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
LoRA kernels metadata preparation utilities.
V1 LoRA kernels metadata preparation utilities.
"""
from dataclasses import dataclass
@ -10,7 +10,7 @@ import torch
@dataclass
class LoRAKernelMeta:
class V1KernelMeta:
token_lora_mapping: torch.Tensor
token_indices_sorted_by_lora_ids: torch.Tensor
active_lora_ids: torch.Tensor
@ -19,7 +19,7 @@ class LoRAKernelMeta:
@staticmethod
def make(max_loras: int, max_num_tokens: int,
device: Union[torch.device, str]) -> "LoRAKernelMeta":
device: Union[torch.device, str]) -> "V1KernelMeta":
token_lora_mapping = torch.empty(max_num_tokens,
dtype=torch.int32,
@ -47,7 +47,7 @@ class LoRAKernelMeta:
lora_token_start_loc = torch.zeros(max_loras + 2,
dtype=torch.int32,
device=device)
return LoRAKernelMeta(
return V1KernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
@ -105,7 +105,7 @@ class LoRAKernelMeta:
This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the
metadata required by the kernel, in order, as a tuple, so it can be
unpacked directly during the lora_shrink/lora_expand function call.
unpacked directly during the v1_shrink/v1_expand function call.
Args:
token_nums (int): Number of input tokens in the current forward

View File

@ -18,15 +18,15 @@ from vllm.utils import direct_register_custom_op
@triton.jit
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
lora_token_start_loc, lora_ids, scaling,
input_d0_stride, input_d1_stride, lora_d0_stride,
lora_d1_stride, lora_d2_stride, output_d0_stride,
output_d1_stride, output_d2_stride,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
lora_token_start_loc, lora_ids, scaling, input_d0_stride,
input_d1_stride, lora_d0_stride, lora_d1_stride,
lora_d2_stride, output_d0_stride, output_d1_stride,
output_d2_stride, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
@ -96,7 +96,7 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
@torch.inference_mode()
def _lora_shrink(
def _v1_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: List[
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
@ -174,7 +174,7 @@ def _lora_shrink(
MAX_LORAS,
)
_lora_shrink_kernel[grid](
_v1_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
@ -209,7 +209,7 @@ def _lora_shrink(
return
def _lora_shrink_fake(
def _v1_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
@ -225,12 +225,12 @@ def _lora_shrink_fake(
try:
direct_register_custom_op(
op_name="lora_shrink",
op_func=_lora_shrink,
op_name="v1_shrink",
op_func=_v1_shrink,
mutates_args=["output_tensor"],
fake_impl=_lora_shrink_fake,
fake_impl=_v1_shrink_fake,
)
lora_shrink = torch.ops.vllm.lora_shrink
v1_shrink = torch.ops.vllm.v1_shrink
except AttributeError:
lora_shrink = _lora_shrink
v1_shrink = _v1_shrink

View File

@ -10,12 +10,20 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
import torch
import vllm.envs as env
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
lora_shrink)
if env.VLLM_USE_V1:
from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand,
v1_shrink)
else:
from vllm.lora.ops.triton_ops import bgmv_expand
from vllm.lora.ops.triton_ops import bgmv_expand_slice
from vllm.lora.ops.triton_ops import bgmv_shrink
from vllm.lora.ops.triton_ops import sgmv_expand
from vllm.lora.ops.triton_ops import sgmv_shrink
from .punica_base import PunicaWrapperBase
@ -24,8 +32,57 @@ if TYPE_CHECKING:
from vllm.lora.models import LongContextLoRAContext
class V1KernelMixin:
def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int,
max_batches: int, device: Union[torch.device, str]):
self.token_mapping_v1_meta = V1KernelMeta.make(max_loras,
max_num_batched_tokens,
device=device)
self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras,
max_batches,
device=device)
def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor,
sampler_indices: torch.Tensor):
self.token_mapping_v1_meta.prepare_tensors(token_lora_indices)
self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices)
def _v1_apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
scale: float,
):
v1_shrink(
x,
w_t_all,
y,
*self.token_mapping_v1_meta.meta_args(x.size(0)),
scale,
)
def _v1_apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
offset_start: int,
add_inputs: bool,
):
v1_expand(
x,
w_t_all,
y,
*self.token_mapping_v1_meta.meta_args(x.size(0)),
offset_start=offset_start,
add_inputs=add_inputs,
)
@final
class PunicaWrapperGPU(PunicaWrapperBase):
class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
"""
PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
@ -39,12 +96,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.max_loras = kwargs['max_loras']
self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
max_num_batched_tokens,
device=device)
self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
max_batches,
device=device)
if env.VLLM_USE_V1:
self._v1_make_metadata(self.max_loras, max_num_batched_tokens,
max_batches, device)
def update_metadata(
self,
@ -56,18 +110,83 @@ class PunicaWrapperGPU(PunicaWrapperBase):
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
if env.VLLM_USE_V1:
self.is_prefill = mapping.is_prefill
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
self._v1_prepare_metadata_tensors(self.token_lora_indices,
self.sampler_indices)
else:
# Forward to base class update_metadata
PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id,
max_loras, vocab_size,
extra_vocab_size,
long_lora_context, **kwargs)
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
def _apply_shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor,
...], scale: float, **kwargs):
def _apply_shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _apply_expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
offset_start: int,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
offset_start=offset_start,
add_inputs=add_inputs,
)
def _apply_expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_inputs: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_inputs)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
@ -80,24 +199,33 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (torch.Tensor): Output tensors
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
lora_shrink(
x,
lora_a_stacked,
y,
*self.token_mapping_meta.meta_args(x.size(0)),
scale,
)
if env.VLLM_USE_V1:
self._v1_apply_shrink(y, x, lora_a_stacked, scale) # type: ignore
else:
if self.is_prefill:
# NOTE fused kernel
self._apply_shrink_prefill(
y, # type: ignore
x,
lora_a_stacked,
scale)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink_decode(y[slice_idx], x,
lora_a_stacked[slice_idx], scale)
def add_expand(self,
y: torch.Tensor,
x: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
@ -116,7 +244,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
@ -131,19 +259,37 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self._apply_bias(token_lora_indices, y, output_slices,
lora_bias_stacked)
assert x.ndim == 3
assert x.size(0) == len(output_slices)
num_tokens = x.size(1) # first dimension is the num slices
lora_expand(
x,
lora_b_stacked,
y,
*self.token_mapping_meta.meta_args(num_tokens),
offset_start=offset_start,
add_inputs=True,
)
if env.VLLM_USE_V1:
# TODO (varun): Profile with add_inputs = False. i.e. move the
# addition out of the kernel
self._v1_apply_expand(
y,
x, # type: ignore
lora_b_stacked,
offset_start,
add_inputs=True)
else:
if self.is_prefill:
# NOTE fused kernel
self._apply_expand_prefill(
y,
x, # type: ignore
lora_b_stacked,
offset_start,
add_inputs=True)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand_decode(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_start,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_start += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(self,
@ -165,14 +311,24 @@ class PunicaWrapperGPU(PunicaWrapperBase):
add_inputs (bool): Default to True.
"""
lora_expand(
x.unsqueeze(dim=0),
(lora_b_stacked, ),
y,
*self.token_mapping_meta.meta_args(x.size(0)),
offset_start=0,
add_inputs=add_inputs,
)
if env.VLLM_USE_V1:
self._v1_apply_expand(y,
x.unsqueeze(dim=0), (lora_b_stacked, ),
offset_start=0,
add_inputs=add_inputs)
else:
if self.is_prefill:
sgmv_expand(
x.unsqueeze(dim=0),
(lora_b_stacked, ),
y,
*self.prefill_metadata,
offset_start=0,
add_inputs=add_inputs,
)
else:
bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
add_inputs)
def add_lora_linear(self,
y: torch.Tensor,
@ -183,7 +339,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[torch.Tensor] = None,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
@ -205,7 +361,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
@ -275,11 +431,21 @@ class PunicaWrapperGPU(PunicaWrapperBase):
dtype=torch.float32,
device=x.device)
lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
*self.prompt_mapping_meta.meta_args(x.size(0)), scale)
if env.VLLM_USE_V1:
v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
*self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale)
lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
y,
*self.prompt_mapping_meta.meta_args(buffer.size(0)),
add_inputs=True)
v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
y,
*self.prompt_mapping_v1_meta.meta_args(buffer.size(0)),
add_inputs=True)
else:
# V0 LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)

View File

@ -9,6 +9,7 @@ from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@ -25,7 +26,7 @@ def maybe_backend_fallback(
def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
fallback: str) -> None:
"""Change the backend to the specified fallback with a warning log,
"""Change the backend to the specified fallback with a warning log,
or raise a ValueError if the `no-fallback` option is specified."""
if guided_params.no_fallback():
raise ValueError(message)
@ -52,12 +53,19 @@ def maybe_backend_fallback(
if guided_params.backend_name == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
fallback_or_error(guided_params,
"xgrammar is only supported on x86 CPUs.",
"outlines")
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")
# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):

View File

@ -9,11 +9,13 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, List
import torch
from transformers import PreTrainedTokenizerFast
from vllm.logger import init_logger
try:
import xgrammar as xgr
from xgrammar.base import _core as xgr_core
xgr_installed = True
except ImportError:
xgr_installed = False
@ -33,6 +35,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# TODO: passing batch size to max threads here
def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
@ -49,8 +52,18 @@ def get_local_xgrammar_guided_decoding_logits_processor(
@dataclass(frozen=True)
class TokenizerData:
"""Immutable container for cached tokenizer data."""
metadata: str
encoded_vocab: list[str] = field(default_factory=list)
stop_token_ids: list[int] | None = None
# These fields are mutually exclusive: `backend_str` is used to create a
# TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
# used within the constructor of TokenizeInfo
backend_str: str | None = None
vocab_type: xgr.VocabType | None = None
def __post_init__(self):
# Check for mutual exclusive
assert not (self.backend_str and self.vocab_type), \
"backend_str and vocab_type are mutual exclusive"
class TokenizerDataCache:
@ -58,52 +71,46 @@ class TokenizerDataCache:
_cache: dict[int, TokenizerData] = {}
@classmethod
def get_tokenizer_data(
cls,
tokenizer: PreTrainedTokenizer,
/,
*,
tokenizer_hash: int,
vocab_size: int,
) -> TokenizerData:
def get_tokenizer_data(cls,
tokenizer: PreTrainedTokenizer) -> TokenizerData:
tokenizer_hash = hash(tokenizer)
if tokenizer_hash not in cls._cache:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
# NOTE: We will need to use lm_head's vocab_size
# to determine correct special_token_ids for this tokenizer.
# See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
vocab_size=vocab_size,
)
metadata = json.loads(tokenizer_info.dump_metadata())
# Vendored from xgrammar logic to get encoded_vocab
# https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
# Vendored from xgrammar logic since we cannot pickle the tokenizer
# https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
try:
vocab_dict = tokenizer.get_vocab()
encoded_vocab = [
token for token, _ in sorted(tokenizer.get_vocab().items(),
key=lambda x: x[1])
]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
# maintain tokenizer's indexing
encoded_vocab = [""] * tokenizer_info.vocab_size
for token, idx in vocab_dict.items():
if idx < tokenizer_info.vocab_size:
encoded_vocab[idx] = token
stop_token_ids = None
backend_str = ""
vocab_type = xgr.VocabType.RAW
if isinstance(tokenizer, MistralTokenizer):
if stop_token_ids is None and hasattr(
tokenizer,
"eos_token_id") and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
if isinstance(tokenizer, PreTrainedTokenizerFast):
backend_str = tokenizer.backend_tokenizer.to_str()
vocab_type = None
elif isinstance(tokenizer, MistralTokenizer):
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
metadata.update({
"vocab_type": xgr.VocabType.BYTE_FALLBACK,
"add_prefix_space": True
})
vocab_type = xgr.VocabType.BYTE_FALLBACK
cls._cache[tokenizer_hash] = TokenizerData(
encoded_vocab=encoded_vocab,
metadata=json.dumps(metadata),
)
stop_token_ids=stop_token_ids,
backend_str=backend_str,
vocab_type=vocab_type)
return cls._cache[tokenizer_hash]
@ -122,15 +129,30 @@ class GrammarCompilerCache:
cache_key = str(config.tokenizer_hash)
if cache_key not in cls._cache:
assert config.tokenizer_data is not None
assert config.tokenizer_data.encoded_vocab is not None
config_data = config.tokenizer_data
# In TokenizerDataCache.get_tokenizer_data, a serializable
# tokenizer_data is created and cached. This data is used to build
# a tokenizer_info and create an xgrammar compiler.
tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
encoded_vocab=config_data.encoded_vocab,
metadata=config_data.metadata,
)
# - If tokenizer_data has backend_str set, use
# xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
# - Otherwise, use the default constructor with vocab_type.
# - xgr_core.TokenizerInfo.from_huggingface !=
# xgr.TokenizerInfo.from_huggingface.
if config_data.backend_str:
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
xgr_core.TokenizerInfo.from_huggingface(
config_data.encoded_vocab, config_data.backend_str,
config.vocab_size, config_data.stop_token_ids))
else:
tokenizer_info = xgr.TokenizerInfo(
config_data.encoded_vocab,
config_data.vocab_type,
vocab_size=config.vocab_size,
stop_token_ids=config_data.stop_token_ids)
cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads)
@ -141,12 +163,13 @@ class GrammarCompilerCache:
class GrammarConfig:
"""Serializable configuration for grammar compilation"""
tokenizer_hash: int
tokenizer_data: TokenizerData
vocab_size: int
json_str: str | None = None
grammar_str: str | None = None
json_object: bool | None = None
any_whitespace: bool = True
max_threads: int = 8
tokenizer_data: TokenizerData | None = None
@classmethod
def from_guided_params(cls,
@ -156,11 +179,7 @@ class GrammarConfig:
max_threads: int = 8) -> GrammarConfig:
tokenizer_hash = hash(tokenizer)
tokenizer_data = TokenizerDataCache.get_tokenizer_data(
tokenizer,
tokenizer_hash=tokenizer_hash,
vocab_size=model_config.hf_text_config.vocab_size,
)
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
if guided_params.json:
if not isinstance(guided_params.json, str):
@ -199,6 +218,7 @@ class GrammarConfig:
raise ValueError(str(err)) from err
return cls(json_str=json_str,
vocab_size=model_config.hf_text_config.vocab_size,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
@ -226,12 +246,14 @@ class GrammarConfig:
raise ValueError(str(err)) from err
return cls(grammar_str=grammar_str,
vocab_size=model_config.hf_text_config.vocab_size,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data)
elif guided_params.json_object:
return cls(
json_object=True,
vocab_size=model_config.hf_text_config.vocab_size,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
@ -245,6 +267,7 @@ class GrammarConfig:
return cls(
grammar_str=choice_str,
vocab_size=model_config.hf_text_config.vocab_size,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
@ -268,13 +291,6 @@ class GrammarConfig:
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
return grammar
@staticmethod
def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
return xgr.TokenizerInfo.from_vocab_and_metadata(
encoded_vocab=tokenizer_data.encoded_vocab,
metadata=tokenizer_data.metadata,
)
@dataclass
class XGrammarLogitsProcessor:
@ -283,16 +299,11 @@ class XGrammarLogitsProcessor:
reasoner: Reasoner | None = None
ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
token_bitmask: torch.Tensor = None # type: ignore[assignment]
matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
batch_size: int = field(default=1)
prefilled: bool = field(default=False)
def __post_init__(self):
self.tokenizer_info = self.config.tokenizer_info(
self.config.tokenizer_data)
def __getstate__(self) -> dict[str, Any]:
return {'config': self.config, 'reasoner': self.reasoner}
@ -300,8 +311,6 @@ class XGrammarLogitsProcessor:
self.config = state['config']
self.reasoner = state['reasoner']
self.tokenizer_info = GrammarConfig.tokenizer_info(
self.config.tokenizer_data)
self.ctx = None
self.matchers = []
self.batch_size = 1
@ -343,7 +352,7 @@ class XGrammarLogitsProcessor:
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
]
self.token_bitmask = xgr.allocate_token_bitmask(
self.batch_size, self.tokenizer_info.vocab_size)
self.batch_size, self.config.vocab_size)
if not self.prefilled:
# Have not sampled a token yet

View File

@ -245,6 +245,7 @@ class MambaMixer2(CustomOp):
assert num_heads % self.tp_size == 0, \
"Tensor parallel world size must divide num heads."
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
(
"If tensor parallel world size does not divide num_heads, "

View File

@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@ -139,10 +140,6 @@ def _fused_moe_gguf(
qweight_type2: int,
act,
) -> torch.Tensor:
# lazy import to avoid triggering triton import in CPU backend
from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size)
out_hidden_states = torch.empty_like(x)
if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
num_tokens, _ = x.shape

View File

@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token
@ -565,11 +565,12 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return Blip2ImagePixelInputs(
type="pixel_values",
@ -577,11 +578,12 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
image_embeds = flatten_bn(image_embeds, concat=True)
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return Blip2ImageEmbeddingInputs(
type="image_embeds",

View File

@ -39,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
@ -972,11 +972,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return ChameleonImagePixelInputs(
type="pixel_values",

View File

@ -478,7 +478,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
flatten_bn(images_spatial_crop, concat=True)))
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

View File

@ -25,7 +25,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsV0Only)
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@ -374,7 +374,7 @@ class Gemma3MultiModalProjector(nn.Module):
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA, SupportsV0Only):
SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",

View File

@ -578,7 +578,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

View File

@ -838,7 +838,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
@ -856,9 +856,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat)}")
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(pixel_values_flat)}")
assert isinstance(image_num_patches, (torch.Tensor, list))
return InternVLImagePixelInputs(
type="pixel_values",

View File

@ -36,6 +36,8 @@ from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class JambaMoE(nn.Module):

View File

@ -349,18 +349,21 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
List[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
pixel_values = kwargs.pop("pixel_values_videos", None)
if pixel_values_videos is None:
if pixel_values is None:
return None
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel_values_videos. "
f"Got type: {type(pixel_values_videos)}")
if not (is_list_of(pixel_values,
(torch.Tensor)) # different shape videos
or isinstance(pixel_values,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaNextVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values_videos,
data=pixel_values,
)
def _select_image_features(self, image_features: torch.Tensor, *,

View File

@ -574,7 +574,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values_videos is None:
return None
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
if not (is_list_of(pixel_values_videos,
torch.Tensor) # different shape videos
or isinstance(pixel_values_videos,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel_values_videos. "
f"Got type: {type(pixel_values_videos)}")

View File

@ -111,7 +111,6 @@ class MixtralAttention(nn.Module):
def __init__(
self,
config: MixtralConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
@ -137,9 +136,7 @@ class MixtralAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MixtralConfig has an optional head_dim argument
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@ -203,7 +200,6 @@ class MixtralDecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,

View File

@ -165,7 +165,6 @@ class MixtralAttention(nn.Module):
def __init__(
self,
config: MixtralConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
@ -191,9 +190,7 @@ class MixtralAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MixtralConfig has an optional head_dim argument
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@ -255,7 +252,6 @@ class MixtralDecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,

View File

@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -283,19 +283,17 @@ class Olmo2Model(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
else:
hidden_states = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
@ -339,7 +337,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@ -348,13 +346,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states

View File

@ -23,7 +23,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
@ -270,11 +270,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return PaliGemmaImagePixelInputs(
type="pixel_values",
@ -286,7 +287,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
image_embeds = flatten_bn(image_embeds, concat=True)
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",

View File

@ -73,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict):
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
@ -849,10 +849,10 @@ class VisionTransformer(nn.Module):
) -> torch.Tensor:
"""
Args:
images: list of N_img images of variable sizes,
images: list of N_img images of variable sizes,
each of shape (C, H, W)
Returns:
image_features: tensor of token features for
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
@ -935,8 +935,7 @@ class PatchMerger(nn.Module):
# x is (N, vision_encoder_dim)
x = self.permute(x, image_sizes)
# x is (N / spatial_merge_size ** 2,
# vision_encoder_dim * spatial_merge_size ** 2)
# x is (N / spatial_merge_size ** 2, vision_encoder_dim * spatial_merge_size ** 2)
x = self.merging_layer(x)
# x is (N / spatial_merge_size ** 2, vision_encoder_dim)

View File

@ -711,7 +711,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
@ -722,13 +722,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return QwenImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
data=flatten_bn(image_embeds),
)
return None

View File

@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = {
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),

File diff suppressed because it is too large Load Diff

View File

@ -40,7 +40,7 @@ class StructuredOutputManager:
tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = self.vllm_config.model_config.get_vocab_size()
self.vocab_size = len(tokenizer.get_vocab())
if isinstance(tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98

View File

@ -62,10 +62,9 @@ class LoRAModelRunnerMixin:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
# Set is_prefill to True, so we always use the SGMV kernels on
# non-cuda platforms.
# On cuda platforms we use the same kernels for prefill and
# decode and this flag is generally ignored.
# Set is_prefill to True, so we always use the SGMV kernels.
# For cuda platforms, we have specialized triton kernels, and
# the cuda path ignores `is_prefill`.
lora_mapping = LoRAMapping(token_lora_mapping,
prompt_lora_mapping,
is_prefill=True)