Add: Support for Sparse24Bitmask Compressed Models

Author: Rahul Tuli
Date: 2025-02-05 15:30:43 -06:00
Committed by: GitHub
Parent: af8486de49
Commit: 3b2005e1db
4 changed files with 503 additions and 112 deletions
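This change lets vLLM load checkpoints whose 2:4-sparse weights are stored in the compressed "sparse-24-bitmask" format; previously the sparsity scheme map only accepted format == "dense", as the assertion change in the diff below shows. A minimal usage sketch, using one of the test checkpoints from this diff (vLLM infers the scheme from the checkpoint's compressed-tensors config):

from vllm import LLM

# Load a 2:4-pruned checkpoint stored in sparse-24-bitmask form; the
# CompressedTensors24 scheme is selected from the checkpoint config,
# as the tests below assert.
llm = LLM(model="nm-testing/llama2.c-stories42M-pruned2.4-compressed")
output = llm.generate("Hello my name is")
print(output[0].outputs[0].text)

Running this requires a GPU with sparse CUTLASS support; the tests gate on sparse_cutlass_supported() for the same reason.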

tests/quantization/test_compressed_tensors.py

@@ -3,6 +3,7 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from typing import Optional
import pytest
@@ -22,12 +23,30 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize(
"model_args",
[("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560, True),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560, True),
("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
QuantizationType.INT, 2560, False)])
[
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
"tensor",
QuantizationType.INT,
2560,
True,
),
(
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
"channel",
QuantizationType.INT,
2560,
True,
),
(
"nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
"tensor",
QuantizationType.INT,
2560,
False,
),
],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0, is_symmetric = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
@@ -85,21 +104,31 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
assert output
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
])
@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
example_prompts, model_path,
max_tokens, num_logprobs):
def test_compressed_tensors_w8a8_logprobs(
hf_runner,
vllm_runner,
example_prompts,
model_path,
max_tokens,
num_logprobs,
):
dtype = "bfloat16"
# skip language translation prompt for the static per tensor asym model
if model_path == "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym": # noqa: E501
if (model_path ==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
): # noqa: E501
example_prompts = example_prompts[0:-1]
with hf_runner(model_path, dtype=dtype) as hf_model:
@@ -125,13 +154,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
assert output
@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
"channel"),
])
@pytest.mark.parametrize(
"model_args",
[
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
"channel",
),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
"channel",
),
],
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
@@ -156,9 +193,12 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
@pytest.mark.parametrize(
"wNa16_args",
[("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
[
("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
],
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args
with vllm_runner(model) as llm:
@@ -218,7 +258,8 @@ def test_compressed_tensors_fp8(vllm_runner):
CompressedTensorsLinearMethod)
assert isinstance(
qkv_proj.scheme,
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
)
assert qkv_proj.input_scale.dtype is torch.float32
@@ -241,9 +282,14 @@ def test_compressed_tensors_kv_cache(vllm_runner):
assert output
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
def _test_2of4_quant_models(qkv_proj,
weight_strategy,
input_strategy,
format="dense"):
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
@@ -252,22 +298,39 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
assert qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").format == format
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
"token"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
"channel", "tensor"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
"tensor"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor", "token"),
])
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
"channel",
"token",
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
"channel",
"tensor",
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
"tensor",
"tensor",
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor",
"token",
),
],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
@@ -286,16 +349,134 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
assert output
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
"channel", "token"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor",
"tensor"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor", "token"),
])
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
"channel",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
"channel",
"tensor",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
"tensor",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
"tensor",
"tensor",
),
],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(
qkv_proj,
weight_strategy,
input_strategy,
format="sparse-24-bitmask",
)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
"channel",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
"channel",
"tensor",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
"tensor",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
"tensor",
"tensor",
),
],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(
qkv_proj,
weight_strategy,
input_strategy,
format="sparse-24-bitmask",
)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
"channel",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
"tensor",
"tensor",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor",
"token",
),
],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
@@ -317,10 +498,12 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="2of4 Sparse is not yet supported on this GPU type.")
reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model) as llm:
@@ -337,7 +520,9 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
sparsity_map = (
qkv_proj.quant_method.quantization_config.sparsity_scheme_map
) # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
@@ -346,3 +531,38 @@
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")])
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = (
qkv_proj.quant_method.quantization_config.sparsity_scheme_map
) # noqa: E501
assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
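For reference, the "sparse-24-bitmask" format these tests assert on stores, for each group of four weights, the two retained values plus a bitmask marking their positions. A conceptual round-trip sketch follows (illustration only, not the compressed-tensors implementation; both helper names are hypothetical):

import torch

# Keep the two largest-magnitude values in each group of four and
# record their positions in a boolean bitmask.
def compress_2of4(dense: torch.Tensor):
    groups = dense.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    bitmask = torch.zeros_like(groups, dtype=torch.bool)
    bitmask.scatter_(1, keep, True)
    values = groups[bitmask].reshape(-1, 2)  # two survivors per group
    return values, bitmask

# Scatter the stored values back into a dense tensor of the original shape.
def decompress_2of4(values: torch.Tensor, bitmask: torch.Tensor,
                    shape: torch.Size) -> torch.Tensor:
    dense = torch.zeros(bitmask.shape, dtype=values.dtype)
    dense[bitmask] = values.reshape(-1)
    return dense.reshape(shape)

w = torch.randn(8, 8)
values, bitmask = compress_2of4(w)
restored = decompress_2of4(values, bitmask, w.shape)
# At most two nonzeros survive in every group of four: the 2:4 pattern.
assert (restored != 0).reshape(-1, 4).sum(dim=-1).max() <= 2

In the actual checkpoints the bitmask is presumably bit-packed rather than stored as booleans, and the values keep their quantized dtype (torch.float8_e4m3fn or torch.int8 in the tests above), which is what the weights_dtype assertions check.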