[Perf] Fix and reapply move apply w8a8 block fp8 linear to class (#25696)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@ -11,7 +11,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (fp8_gemm_nt,
|
||||
@ -91,7 +91,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
|
||||
@ -20,9 +20,11 @@ from vllm.platforms import current_platform
|
||||
(8, 513, 64), # Non-divisible (native only)
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
group_size: int, seed: int) -> None:
|
||||
group_size: int, seed: int,
|
||||
use_ue8m0: bool) -> None:
|
||||
"""Test QuantFP8 group quantization with various configurations.
|
||||
|
||||
Tests both CUDA and native implementations, column-major scales,
|
||||
@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False)
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0)
|
||||
|
||||
# 1. Test native implementation (always available)
|
||||
x_quant_native, scales_native = quant_op.forward_native(x.clone())
|
||||
@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
# 2. Test column-major scales configuration
|
||||
quant_op_col = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True)
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0)
|
||||
_, scales_col = quant_op_col.forward_native(x.clone())
|
||||
assert scales_col.shape == (expected_num_groups, batch_size)
|
||||
assert scales_col.shape == (batch_size, expected_num_groups)
|
||||
assert scales_col.stride(0) == 1
|
||||
assert scales_col.stride(1) == batch_size
|
||||
|
||||
# Test column-major scales consistency
|
||||
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
|
||||
# 3. Test CUDA implementation (only for divisible dimensions)
|
||||
if is_divisible:
|
||||
@ -68,21 +77,23 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_multidimensional(seed: int) -> None:
|
||||
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
group_size = 64
|
||||
|
||||
# Test with 3D input
|
||||
batch1, batch2, hidden_dim = 4, 8, 512
|
||||
batch1, batch2, hidden_dim = 4, 8, 1024
|
||||
x_3d = torch.randn(
|
||||
(batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
|
||||
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False)
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0)
|
||||
|
||||
x_quant, scales = quant_op.forward_native(x_3d.clone())
|
||||
assert x_quant.shape == x_3d.shape
|
||||
@ -91,9 +102,10 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
|
||||
# Test column_major_scales with multi-dim
|
||||
quant_op_col = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True)
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0)
|
||||
_, scales_col = quant_op_col.forward_native(x_3d.clone())
|
||||
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
|
||||
assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)
|
||||
|
||||
# Test with 4D input
|
||||
batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
|
||||
|
||||
Reference in New Issue
Block a user