[Kernel] Add w8a8 CUTLASS kernels (#4749)

This commit is contained in:
Tyler Michael Smith
2024-05-16 18:32:50 -04:00
committed by GitHub
parent 8435b207af
commit 2060e93659
10 changed files with 1197 additions and 2 deletions

View File

@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization.");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");