[NVIDIA] Support nvfp4 cutlass gemm (#13571)

This commit is contained in:
Kaixi Hou
2025-02-22 05:24:05 -08:00
committed by GitHub
parent 8db1b9d0a1
commit e109e598c7
7 changed files with 494 additions and 1 deletion

View File

@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"SymInt size_k) -> Tensor");
// conditionally compiled so impl registration is in source file
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(