[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-07-25 20:07:07 -04:00
committed by GitHub
parent 41d3082c41
commit 75d29cf4e1
6 changed files with 47 additions and 3 deletions

View File

@ -624,6 +624,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8);
// Compute per-token-group INT8 quantized tensor and scaling factor.
ops.def(
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
ops.impl("per_token_group_quant_int8", torch::kCUDA,
&per_token_group_quant_int8);
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "