[Kernel] Add per-tensor and per-token AZP epilogues (#5941)
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
@@ -166,13 +166,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
-  // quantization.
+  // quantization, as well as bias
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       "                  Tensor b, Tensor a_scales,"
       "                  Tensor b_scales, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
+  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
+  // quantization.
+  ops.def(
+      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
+      "                      Tensor b, Tensor a_scales,"
+      "                      Tensor b_scales, Tensor azp_adj,"
+      "                      Tensor? azp, Tensor? bias) -> ()");
+  ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
+
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
   // capability
   ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
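For context, the AZP (asymmetric zero point) epilogues fold the zero-point correction into the GEMM. With activations quantized asymmetrically (a ~ a_scales * (a_q - azp)) and weights symmetrically (b ~ b_scales * b_q), the product expands to a_scales * b_scales * (a_q @ b_q - azp * colsum(b_q)), where colsum(b_q) is what the schema calls azp_adj. Below is a minimal PyTorch sketch of that math; the shapes, broadcasting, and the per-tensor convention of pre-folding the scalar azp into azp_adj are assumptions read off the schema above, not a verbatim copy of the kernel.

from typing import Optional

import torch

def scaled_mm_azp_ref(a_q: torch.Tensor,       # int8 activations, [M, K]
                      b_q: torch.Tensor,       # int8 weights, [K, N]
                      a_scales: torch.Tensor,  # fp32, scalar or [M, 1]
                      b_scales: torch.Tensor,  # fp32, scalar or [1, N]
                      azp_adj: torch.Tensor,   # int32, [N]: column sums of b_q
                      azp: Optional[torch.Tensor] = None,   # int32, [M, 1]
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Integer matmul in an int32 accumulator (runs on CPU; a plain
    # reference, not the CUTLASS path).
    acc = a_q.to(torch.int32) @ b_q.to(torch.int32)
    if azp is None:
        # Per-tensor AZP: the scalar zero point is assumed pre-folded into
        # azp_adj, i.e. azp_adj = azp * b_q.sum(dim=0).
        acc = acc - azp_adj
    else:
        # Per-token AZP: one zero point per row of A, with
        # azp_adj = b_q.sum(dim=0).
        acc = acc - azp * azp_adj
    out = a_scales * b_scales * acc.to(torch.float32)
    if bias is not None:
        out = out + bias
    return out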
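A hedged invocation sketch follows. It assumes the extension is built under the _C namespace (the value of TORCH_EXTENSION_NAME in vLLM builds), that the kernel expects a row-major int8 a, a column-major int8 b, float32 scales, and int32 azp/azp_adj; these layout and dtype requirements are assumptions carried over from the symmetric cutlass_scaled_mm, not stated in the schema. Tensor! out in the schema means the caller preallocates the output and the op writes it in place.

import torch

M, K, N = 16, 64, 32
a = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")
# Column-major b, built by transposing a contiguous (N, K) tensor.
b = torch.randint(-128, 128, (N, K), dtype=torch.int8, device="cuda").t()
a_scales = torch.rand(M, 1, dtype=torch.float32, device="cuda")
b_scales = torch.rand(1, N, dtype=torch.float32, device="cuda")
azp = torch.randint(-100, 100, (M, 1), dtype=torch.int32, device="cuda")
# Column sums of b; the exact expected shape ([1, N] here) is an assumption.
azp_adj = b.to(torch.int32).sum(dim=0, keepdim=True)
out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")

# Per-token AZP path: azp passed explicitly, no bias. For the per-tensor
# path, pass azp=None and fold the scalar zero point into azp_adj instead.
torch.ops._C.cutlass_scaled_mm_azp(out, a, b, a_scales, b_scales,
                                   azp_adj, azp, None)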