[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Authored by: elvischenv, 2025-08-23 06:09:05 +08:00
Committed by: GitHub
parent cc7ae5e7ca
commit 24d0c9e6ed
27 changed files with 596 additions and 200 deletions

@@ -6,7 +6,11 @@ import flashinfer
 import pytest
 import torch
 
+from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
+                                                    FLOAT8_E4M3_MAX,
+                                                    dequantize_nvfp4_to_dtype)
 from vllm.platforms import current_platform
+from vllm.utils import round_up
 
 if not current_platform.is_device_capability(100):
     pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
@@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100):
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 FP8_DTYPE = current_platform.fp8_dtype()
+FP4_DTYPE = torch.uint8
 
 
 def to_float8(x, dtype=torch.float8_e4m3fn):
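The body of to_float8 is elided by the diff viewer. For readers, a minimal sketch of what a per-tensor FP8 quantizer with this signature typically looks like (an assumption, not the committed implementation):

    import torch

    def to_float8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
        # Per-tensor symmetric quantization: map the tensor's absolute
        # maximum onto the FP8 format's range; return (quantized, scale).
        finfo = torch.finfo(dtype)
        scale = x.abs().amax().clamp(min=1e-12).float() / finfo.max
        x_q = (x / scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
        return x_q, scale

The tests below only consume the returned scale (`_, o_scale = to_float8(output)`), using it to fold output dequantization into bmm2_scale.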
@@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16]
 QUANT_DTYPES = [
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
+    (None, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+    (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]
 BATCH_SIZE = [4, 12]
 MAX_SEQ_LENS = [(1024, 4096)]
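Each QUANT_DTYPES entry is (q_quant_dtype, kv_quant_dtype, o_quant_dtype), with None falling back to the base compute dtype via the `x or dtype` pattern visible in the hunks below. A short illustration (assumption: FP8_DTYPE shown as torch.float8_e4m3fn, which current_platform.fp8_dtype() resolves to on CUDA):

    import torch

    FP8_DTYPE = torch.float8_e4m3fn  # assumption, see current_platform.fp8_dtype()
    dtype = torch.bfloat16

    q, kv, o = (None, FP8_DTYPE, None)   # the new mixed entry
    q_quant_dtype = q or dtype           # None -> compute Q in bf16
    kv_quant_dtype = kv or dtype         # KV cache quantized to FP8
    o_quant_dtype = o or dtype           # output produced in bf16

Note that this mixed q/kv entry is only exercised by the decode test; the prefill test skips it (see the `pytest.skip("Skipped mixed QKV dtypes for prefill")` hunk below).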
@@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
 
     o_scale = 1.0
+    o_sf_scale = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
+    elif o_quant_dtype == FP4_DTYPE:
+        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+
     # TRTLLM Decode
-    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+    if o_quant_dtype == FP4_DTYPE:
+        output_trtllm = flashinfer.utils.FP4Tensor(
+            torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
+                        dtype=torch.uint8),
+            torch.empty((round_up(query.shape[0], 128),
+                         round_up(query.shape[1] * query.shape[2] // 16, 4)),
+                        dtype=torch.float8_e4m3fn),
+        )
+    else:
+        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
     flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         query=query,
         kv_cache=kv_cache,
@@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline(
         max_seq_len=max_seq_len,
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
+        o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
 
     if o_quant_dtype == FP8_DTYPE:
         output_trtllm = output_trtllm.to(dtype) * o_scale
+    elif o_quant_dtype == FP4_DTYPE:
+        output_trtllm.data = output_trtllm.data.reshape(
+            -1, query.shape[1] * query.shape[2] // 2)
+        output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
+                                                  output_trtllm.scale,
+                                                  o_sf_scale, dtype,
+                                                  query.device)
+        output_trtllm = output_trtllm.reshape(-1, query.shape[1],
+                                              query.shape[2])
 
-    if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
+    if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
+        rtol, atol = 3e-1, 1e0
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
         rtol, atol = 5e-2, 7e-2
     else:
-        rtol, atol = 1e-2, 1e-2
+        rtol, atol = 1e-2, 2e-2
 
     torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - output_trtllm))}"
@@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     q_quant_dtype = q_quant_dtype or dtype
     kv_quant_dtype = kv_quant_dtype or dtype
     o_quant_dtype = o_quant_dtype or dtype
 
+    if q_quant_dtype != kv_quant_dtype:
+        pytest.skip("Skipped mixed QKV dtypes for prefill")
+
     max_q_len, max_kv_len = max_seq_lens
     num_qo_heads, num_kv_heads = num_heads
@@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
 
     o_scale = 1.0
+    o_sf_scale = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
+    elif o_quant_dtype == FP4_DTYPE:
+        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+
     # TRTLLM Prefill
-    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+    if o_quant_dtype == FP4_DTYPE:
+        output_trtllm = flashinfer.utils.FP4Tensor(
+            torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
+                        dtype=torch.uint8),
+            torch.empty((round_up(query.shape[0], 128),
+                         round_up(query.shape[1] * query.shape[2] // 16, 4)),
+                        dtype=torch.float8_e4m3fn),
+        )
+    else:
+        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
     flashinfer.prefill.trtllm_batch_context_with_kv_cache(
         query=query,
         kv_cache=kv_cache,
@@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         batch_size=batch_size,
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
+        o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
 
     if o_quant_dtype == FP8_DTYPE:
         output_trtllm = output_trtllm.to(dtype) * o_scale
+    elif o_quant_dtype == FP4_DTYPE:
+        output_trtllm.data = output_trtllm.data.reshape(
+            -1, query.shape[1] * query.shape[2] // 2)
+        output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
+                                                  output_trtllm.scale,
+                                                  o_sf_scale, dtype,
+                                                  query.device)
+        output_trtllm = output_trtllm.reshape(-1, query.shape[1],
+                                              query.shape[2])
 
-    if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
+    if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
+        rtol, atol = 4e-1, 1e0
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
         rtol, atol = 5e-2, 7e-2
     else:
         rtol, atol = 1e-2, 1e-2
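For context on the loosened tolerances: an E2M1 element encodes only eight magnitudes (0, 0.5, 1, 1.5, 2, 3, 4, 6), so the FP4-output cases are compared with rtol up to 4e-1 and atol 1e0, versus 5e-2/7e-2 for FP8. A hypothetical unpacking helper, not part of this commit, illustrating the byte packing that dequantize_nvfp4_to_dtype reverses (low-nibble-first ordering is an assumption):

    import torch

    # E2M1 magnitudes, indexed by the 3-bit (exponent << 1 | mantissa) code.
    E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
        """Unpack uint8-packed FP4 (two E2M1 values per byte) to float32."""
        lo = packed & 0x0F          # first element: low nibble (assumed order)
        hi = packed >> 4            # second element: high nibble
        nibbles = torch.stack([lo, hi], dim=-1).flatten(-2).long()
        sign = torch.where((nibbles & 0x8) != 0, -1.0, 1.0)
        return sign * E2M1_VALUES[nibbles & 0x7]

Full dequantization then multiplies each 16-element group by its E4M3 scale factor and divides by o_sf_scale, which is, in outline, what dequantize_nvfp4_to_dtype in tests/kernels/quantization/nvfp4_utils.py does.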