From 7b2f28deba3ce0ad773611f1612f9fc092b0e923 Mon Sep 17 00:00:00 2001
From: Charlie Fu
Date: Wed, 14 May 2025 00:13:56 -0500
Subject: [PATCH] [AMD][torch.compile] Enable silu+fp8_quant fusion for rocm
 (#18082)

Signed-off-by: charlifu
---
 .buildkite/test-pipeline.yaml                        | 1 +
 csrc/quantization/activation_kernels.cu              | 3 ++-
 tests/compile/test_silu_mul_quant_fusion.py          | 6 +++---
 tests/kernels/quantization/test_rocm_skinny_gemms.py | 5 +++--
 tests/kernels/test_fused_quant_activation.py         | 5 +++--
 vllm/compilation/activation_quant_fusion.py          | 3 ++-
 6 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d46459eaee..1040d1e1b8 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -309,6 +309,7 @@ steps:
   commands:
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
+  - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
 
 - label: PyTorch Fullgraph Smoke Test # 9min
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index acc3d67220..67e9149c13 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
 void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
                         torch::Tensor& input,  // [..., 2 * d]
                         torch::Tensor& scale) {
-  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
+              out.dtype() == torch::kFloat8_e4m3fnuz);
   TORCH_CHECK(input.dtype() == torch::kFloat16 ||
               input.dtype() == torch::kBFloat16);
   TORCH_CHECK(input.size(-1) % 2 == 0);
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index f87f175acd..9eae48d60f 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -27,8 +27,8 @@ class TestModel(torch.nn.Module):
 
 @pytest.mark.parametrize("num_tokens", [256])
 @pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
-                    reason="Only test on CUDA")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
+                    reason="Only test on CUDA and ROCm")
 def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
     torch.set_default_device("cuda")
     torch.set_default_dtype(torch.float16)
@@ -36,7 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
     # Reshape pass is needed for the fusion pass to work
     config = VllmConfig()
     config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
+        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
     fusion_pass = ActivationQuantFusionPass(config)
 
     backend = TestBackend(fusion_pass)
diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py
index 76d3316908..c7eee89989 100644
--- a/tests/kernels/quantization/test_rocm_skinny_gemms.py
+++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 @pytest.mark.parametrize("m", M + [28672])  # m >= 16
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.skipif(not current_platform.is_rocm(),
-                    reason="only test for rocm")
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
 def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py
index fa84ad74cd..faa8d49ce4 100644
--- a/tests/kernels/test_fused_quant_activation.py
+++ b/tests/kernels/test_fused_quant_activation.py
@@ -5,9 +5,10 @@ import torch
 import vllm._custom_ops as ops
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
-QUANT_DTYPES = [torch.float8_e4m3fn]
+QUANT_DTYPES = [current_platform.fp8_dtype()]
 NUM_TOKENS = [1, 17, 86, 1234, 3045]  # Arbitrary values for testing
 HIDDEN_SIZES = [16, 48, 128, 1562, 4096]  # Arbitrary values for testing
 SEEDS = [0]
@@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
 def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     out_shape = (x.shape[0], x.shape[1] // 2)
     out = torch.empty(out_shape,
-                      dtype=torch.torch.float8_e4m3fn,
+                      dtype=current_platform.fp8_dtype(),
                       device=x.device)
     torch.ops._C.silu_and_mul_quant(out, x, scale)
     return out
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 1917ed8bbe..dc3e1482e2 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 from .vllm_inductor_pass import VllmInductorPass
 
@@ -41,7 +42,7 @@ def empty_bf16(*args, **kwargs):
 
 
 def empty_fp8(*args, **kwargs):
-    fp8 = torch.float8_e4m3fn
+    fp8 = current_platform.fp8_dtype()
     return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
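
Illustrative sketch (not part of the patch): the core of the change is replacing the hard-coded torch.float8_e4m3fn with current_platform.fp8_dtype() so the CUDA/HIP kernel, the tests, and the fusion pass all agree on the FP8 format native to the running GPU. The snippet below is a minimal stand-alone illustration of that dispatch; pick_fp8_dtype is a hypothetical helper that only mirrors what current_platform.fp8_dtype() is assumed to do, under the assumption that MI300-class ROCm devices use the e4m3fnuz variant while CUDA uses OCP e4m3fn.

import torch

def pick_fp8_dtype(use_fnuz: bool) -> torch.dtype:
    # Hypothetical helper mirroring current_platform.fp8_dtype():
    # OCP e4m3fn on CUDA vs. the e4m3fnuz variant assumed for
    # MI300-class ROCm GPUs (no inf encoding, different exponent bias).
    return torch.float8_e4m3fnuz if use_fnuz else torch.float8_e4m3fn

# The fused silu_and_mul_quant op reads [..., 2 * d] activations and writes
# [..., d] quantized outputs, so the output buffer is allocated with the
# platform-appropriate FP8 dtype and half the trailing dimension.
x = torch.randn(16, 2 * 64, dtype=torch.float16)
out = torch.empty(x.shape[0], x.shape[1] // 2, dtype=pick_fp8_dtype(use_fnuz=False))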