[Kernel] Add ModelOpt FP4 Checkpoint Support (#12520)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-03-11 22:13:11 -07:00
committed by GitHub
parent 5c538c37b2
commit debd6bbf09
10 changed files with 388 additions and 30 deletions

View File

@ -434,6 +434,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
#endif
// Quantized GEMM for GPTQ.