[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
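This change replaces the DeepGEMM `per_token_group_cast_to_fp8` wrapper with vLLM's own Triton kernel, `per_token_group_quant_fp8`, at every call site. For orientation, a minimal pure-PyTorch sketch of what per-token-group FP8 quantization computes (the helper name, `eps` default, and FP8 dtype choice here are illustrative assumptions, not vLLM's API): each contiguous group of `group_size` elements along the last dimension shares one scale, `scale = max(|group|, eps) / fp8_max`, and values are clamped into the FP8 range after dividing by that scale.

```python
import torch


def per_token_group_quant_ref(x: torch.Tensor, group_size: int,
                              eps: float = 1e-10):
    """Pure-PyTorch reference for per-token-group FP8 quantization."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    assert x.shape[-1] % group_size == 0
    # Split the last dimension into groups of `group_size` elements.
    g = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    g = g.to(torch.float32)
    # One scale per group: absmax (floored at eps) divided by the FP8 max.
    scale = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / finfo.max
    # Scale, clamp into the representable FP8 range, and cast.
    q = (g / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)
```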
@@ -13,9 +13,10 @@ import torch
 # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import calc_diff, per_block_cast_to_fp8
 
 BLOCK_SIZE = [128, 128]
 
@@ -81,7 +82,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
     """
     tokens_bf16 = torch.randn(
         m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
-    _, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
+    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
 
     # expert weight tensors
     w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
@@ -15,8 +15,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -117,7 +116,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
 
-    A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
+    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
 
     As = As_fp8.to(torch.float32)
@@ -15,9 +15,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.utils import has_deep_gemm, round_up
-from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
 
 logger = init_logger(__name__)
 
@@ -170,10 +171,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.activation(activation, act_out, mm1_out.view(-1, N))
 
         a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
-                                                     self.block_shape[1],
-                                                     column_major_scales=True,
-                                                     out_q=quant_out)
+        a2q, a2q_scale = per_token_group_quant_fp8(act_out,
+                                                   self.block_shape[1],
+                                                   column_major_scales=True,
+                                                   out_q=quant_out)
 
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)
@@ -15,8 +15,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
-                                  per_token_group_cast_to_fp8)
 
 
 @triton.jit
@@ -119,10 +117,7 @@ def _fp8_quantize(
         assert not per_act_token
         assert len(block_shape) == 2
         _, block_k = block_shape[0], block_shape[1]
-        if is_blackwell_deep_gemm_used():
-            A, A_scale = per_token_group_cast_to_fp8(A, block_k)
-        else:
-            A, A_scale = per_token_group_quant_fp8(A, block_k)
+        A, A_scale = per_token_group_quant_fp8(A, block_k)
         assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
 
     return A, A_scale
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
 logger = init_logger(__name__)
 
@@ -256,6 +257,7 @@ def _per_token_group_quant_fp8(
         # Information for float8
         fp8_min,
         fp8_max,
+        use_ue8m0: tl.constexpr,
         # Meta-parameters
         BLOCK: tl.constexpr,
 ):
@@ -285,7 +287,8 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
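The new `use_ue8m0` branch rounds each group scale up to the nearest power of two via `exp2(ceil(log2(scale_raw)))`, so it is exactly representable in the UE8M0 (exponent-only) scale format used with DeepGEMM on Blackwell; otherwise the raw scale is kept, preserving the previous behaviour. A minimal torch sketch of that rounding (the helper name is illustrative):

```python
import torch


def round_scale_up_to_pow2(scale: torch.Tensor) -> torch.Tensor:
    # Mirrors tl.math.exp2(tl.ceil(tl.log2(scale_raw))) in the kernel above:
    # round each scale *up* to the nearest power of two so it can be stored
    # as a bare exponent (UE8M0).
    return torch.exp2(torch.ceil(torch.log2(scale)))


print(round_scale_up_to_pow2(torch.tensor([0.0123, 0.25, 0.3])))
# tensor([0.0156, 0.2500, 0.5000])  -> 2**-6, 2**-2, 2**-1
```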
@@ -309,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor(
         # Information for float8
         fp8_min,
         fp8_max,
+        use_ue8m0: tl.constexpr,
         # Meta-parameters
         BLOCK: tl.constexpr,
 ):
@@ -347,7 +351,8 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
@@ -373,9 +378,11 @@ def per_token_group_quant_fp8(
         is supported for now.
         column_major_scales: Outputs scales in column major.
+        out_q: Optional output tensor. If not provided, function will create.
     Returns:
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
-        scaling factor for quantization.
+        scaling factor.
     """
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
@@ -418,6 +425,7 @@ def per_token_group_quant_fp8(
             eps,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
+            use_ue8m0=is_blackwell_deep_gemm_used(),
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -433,6 +441,7 @@ def per_token_group_quant_fp8(
             eps,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
+            use_ue8m0=is_blackwell_deep_gemm_used(),
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -49,7 +49,6 @@ if not has_deep_gemm():
     _fp8_gemm_nt_impl: Callable[..., Any] | None = None
     _grouped_impl: Callable[..., Any] | None = None
     _grouped_masked_impl: Callable[..., Any] | None = None
-    _per_token_cast_impl: Callable[..., Any] | None = None
     _per_block_cast_impl: Callable[..., Any] | None = None
 else:
     _dg = importlib.import_module("deep_gemm")  # type: ignore
@@ -74,12 +73,9 @@ else:
     try:
         _math_mod = importlib.import_module(
             "deep_gemm.utils.math")  # type: ignore
-        _per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
-                                       None)
         _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                        None)
     except ModuleNotFoundError:
-        _per_token_cast_impl = None
         _per_block_cast_impl = None
 
 
@@ -101,22 +97,6 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     return _grouped_masked_impl(*args, **kwargs)
 
 
-def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
-    """Wrapper for token-wise FP8 quantisation.
-
-    • If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
-    • Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
-    """
-
-    if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used():
-        assert group_size == 128, "group_size must be 128 for deepgemm"
-        return _per_token_cast_impl(x)
-
-    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-        per_token_group_quant_fp8 as _ptg)
-    return _ptg(x, group_size, *args, **kwargs)
-
-
 def per_block_cast_to_fp8(x, *args, **kwargs):
     if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
         return _per_block_cast_impl(x)
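With the wrapper above removed, call sites import the Triton-backed helper from `fp8_utils` directly, as the earlier hunks already show. A minimal usage sketch (the input tensor and its shape are placeholders):

```python
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
# FP8 values plus one scale per contiguous group of 128 elements.
x_q, x_s = per_token_group_quant_fp8(x, 128)
```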
@@ -146,7 +126,6 @@ __all__ = [
     "fp8_gemm_nt",
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
-    "per_token_group_cast_to_fp8",
     "per_block_cast_to_fp8",
     "is_blackwell_deep_gemm_used",
 ]