[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795)
This commit is contained in:
committed by
GitHub
parent
fad5576c58
commit
75acdaa4b6
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
USE_FP32_REDUCE_OPTS = [False, True]
|
||||
|
||||
MARLIN_K_CHUNKS = [128]
|
||||
MARLIN_N_CHUNKS = [64, 128, 256]
|
||||
@@ -175,6 +176,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_gptq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
@@ -183,6 +185,7 @@ def test_gptq_marlin_gemm(
|
||||
mnk_factors,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
@@ -222,8 +225,9 @@ def test_gptq_marlin_gemm(
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full,
|
||||
is_k_full=is_k_full,
|
||||
has_zp=False,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
@@ -365,12 +369,14 @@ def test_fp8_marlin_gemm(
|
||||
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_awq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
@@ -407,8 +413,9 @@ def test_awq_marlin_gemm(
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full,
|
||||
has_zp,
|
||||
is_k_full=is_k_full,
|
||||
has_zp=has_zp,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user