[Kernel] Adding bias epilogue support for cutlass_scaled_mm (#5560)

Co-authored-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Luka Govedič
2024-06-26 11:16:00 -04:00
committed by GitHub
parent 6984c02a27
commit 5bfd1bbc98
8 changed files with 385 additions and 136 deletions

View File

@ -32,6 +32,7 @@ def cutlass_fp8_gemm_helper(m: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
@ -46,10 +47,17 @@ def cutlass_fp8_gemm_helper(m: int,
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
if bias:
# bias term should be > 1 so that the absolute tolerance can catch it
bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
else:
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
bias_t = 0
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)) +
bias_t).to(out_dtype)
assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
@ -59,6 +67,7 @@ def cutlass_int8_gemm_helper(m: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
@ -74,11 +83,17 @@ def cutlass_int8_gemm_helper(m: int,
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=out_dtype)
if bias:
# bias term should be > 1 so that the absolute tolerance can catch it
bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
else:
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
bias_t = 0
baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)) +
bias_t).to(dtype=out_dtype)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
@ -87,11 +102,12 @@ def cutlass_int8_gemm_helper(m: int,
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool):
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch)
per_out_ch: bool, bias: bool):
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
@pytest.mark.parametrize("m", [512, 222, 33, 1])
@ -99,49 +115,72 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch)
per_out_ch: bool, bias: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype]):
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
out_dtype)
out_dtype: Type[torch.dtype],
bias: bool):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype]):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
out_dtype)
out_dtype: Type[torch.dtype],
bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
device: str):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
bias: bool, device: str):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
torch.bfloat16, device)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
device: str):
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
torch.bfloat16, device)
bias: bool, device: str):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
bias,
out_dtype=torch.bfloat16,
device=device)
# For the following two tests:
@ -151,20 +190,25 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
bias)
# Test working with a subset of A and B