[Perf] Cuda Kernel for Per Token Group Quant (#21083)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-07-22 10:27:15 -04:00
committed by GitHub
parent 2c8db17cfd
commit 774d0c014b
6 changed files with 285 additions and 4 deletions

View File

@ -601,6 +601,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
// Compute per-token-group FP8 quantized tensor and scaling factor.
//
// Schema notes:
//  - `input` is a plain (non-mutated) Tensor argument.
//  - `output_q` and `output_s` carry the `!` annotation, which marks them as
//    tensors the op mutates in place; the op itself returns nothing (`-> ()`).
//    NOTE(review): by name, output_q presumably receives the quantized values
//    and output_s the per-group scales — confirm against the kernel.
//  - `group_size`, `eps`, `fp8_min`, `fp8_max` and `scale_ue8m0` are scalar
//    parameters forwarded to the implementation; their exact semantics are
//    defined by the kernel, not by this registration.
//
// The schema is split across adjacent string literals purely for line length;
// the compiler concatenates them into a single schema string.
ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0) -> ()");
// Bind the CUDA dispatch key for this op to the C++ entry point. Note the
// naming difference: the op is registered as `per_token_group_fp8_quant`,
// while the implementing function is `per_token_group_quant_fp8`.
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"