Support W4A8 quantization for vllm (#5218)

HandH1998 authored on 2024-07-31 21:55:21 +08:00, committed by GitHub
commit 6512937de1 (parent c0644cf9ce)
15 changed files with 1963 additions and 84 deletions


@@ -149,6 +149,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
   ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
 
+  // marlin_qqq_gemm for QQQ.
+  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
+  ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization.
   ops.def(
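
For orientation: the `ops.def`/`ops.impl` pair added above requires a host-side function declaration visible at binding time (typically in csrc/ops.h). Below is a minimal sketch of what such a declaration could look like for a W4A8 Marlin-style GEMM; the parameter names (s_tok, s_ch, s_group, workspace) and their order are illustrative assumptions, not copied from the commit.

// Hypothetical sketch of the declaration the binding expects.
// Names and parameter order are assumptions for illustration only.
#include <torch/all.h>

torch::Tensor marlin_qqq_gemm(
    torch::Tensor const& a,          // int8-quantized activations, size_m x size_k
    torch::Tensor const& b_q_weight, // packed int4 weights
    torch::Tensor const& s_tok,      // per-token activation scales
    torch::Tensor const& s_ch,       // per-channel weight scales
    torch::Tensor const& s_group,    // per-group weight scales
    torch::Tensor& workspace,        // Marlin global-sync workspace
    int64_t size_m, int64_t size_n, int64_t size_k);

Once the extension is built, an op registered this way is reachable from Python through the torch.ops namespace (vLLM builds this library as _C), e.g. torch.ops._C.marlin_qqq_gemm(...).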