refactor quant folder

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
yewentao256
2025-08-07 15:05:05 -07:00
parent 7e3a8dc906
commit f07e10e9bc
5 changed files with 25 additions and 63 deletions

View File

@ -243,7 +243,7 @@ set(VLLM_EXT_SRC
"csrc/sampler.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/8bit/int8/scaled_quant.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
@ -297,7 +297,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu"
"csrc/quantization/fp8/per_token_group_quant.cu")
"csrc/quantization/8bit/fp8/per_token_group_quant.cu"
"csrc/quantization/8bit/int8/per_token_group_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"

View File

@ -8,9 +8,9 @@
#include <torch/all.h>
#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"
#include "../../vectorization.cuh"
#include "../../vectorization_utils.cuh"
#include "../../../dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
unsigned mask = 0xffff;
@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
double fp8_max, bool scale_ue8m0) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0);
}
}

View File

@ -0,0 +1,12 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "../per_token_group_quant_8bit.h"
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}

View File

@ -1,14 +1,10 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#ifndef USE_ROCM
#include "../per_token_group_quant_8bit.h"
#endif
#include <cmath>
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#include "../../../dispatch_utils.h"
#include "../../vectorization_utils.cuh"
#ifndef USE_ROCM
#include <cub/cub.cuh>
@ -25,19 +21,9 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
static constexpr auto i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float dst = std::nearbyint(x);
// saturate
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
// Replace std::clamp due to hip-clang issues
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
return static_cast<int8_t>(dst);
#else
@ -50,26 +36,16 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
// int32_max is not exactly representable as float.
// Therefore, we need to be careful and manually return int32_max on overflow.
// For symmetry, we also do the same for int32_min, even though it is exactly
// representable as float and the conversion should be exact.
static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
static constexpr auto i32_min_f = static_cast<float>(i32_min);
static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
static constexpr auto i32_max_f = static_cast<float>(i32_max);
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float dst = std::nearbyint(x);
// saturate on the higher end.
if (dst >= i32_max_f) {
return i32_max;
}
// saturate on the lower end.
if (dst <= i32_min_f) {
return i32_min;
}
@ -90,13 +66,7 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static constexpr auto i8_max =
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// saturate
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// int32_t dst = std::clamp(x, i8_min, i8_max);
// Replace std::clamp due to hip-clang issues
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
return static_cast<int8_t>(dst);
#else
@ -118,7 +88,6 @@ __global__ void static_scaled_int8_quant_kernel(
const int64_t token_idx = blockIdx.x;
const float scale = *scale_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
@ -140,7 +109,6 @@ __global__ void static_scaled_int8_azp_quant_kernel(
const azp_t azp = *azp_ptr;
const float inv_s = 1.0f / scale;
// Must be performed using 64-bit math to avoid integer overflow.
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
@ -160,11 +128,9 @@ __global__ void dynamic_scaled_int8_quant_kernel(
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
// calculate for absmax
float thread_max = 0.f;
vectorize_read_with_alignment<16>(
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
@ -183,7 +149,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
@ -201,14 +166,12 @@ struct MinMax {
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
// add a value to the MinMax
__host__ __device__ MinMax& operator+=(float v) {
min = fminf(min, v);
max = fmaxf(max, v);
return *this;
}
// merge two MinMax objects
__host__ __device__ MinMax& operator&=(const MinMax& other) {
min = fminf(min, other.min);
max = fmaxf(max, other.max);
@ -231,11 +194,9 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
// 1. calculate min & max
MinMax thread_mm;
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
[&] __device__(const scalar_t& src) {
@ -257,7 +218,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
__shared__ azp_t azp_sh;
if (tid == 0) {
float s = (mm.max - mm.min) / 255.f;
float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
float zp = nearbyintf(-128.f - mm.min / s);
scale_sh = s;
azp_sh = azp_t(zp);
scale_out[blockIdx.x] = s;
@ -268,7 +229,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const float inv_s = 1.f / scale_sh;
const azp_t azp = azp_sh;
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
@ -339,14 +299,4 @@ void dynamic_scaled_int8_quant(
hidden_size);
}
});
}
#ifndef USE_ROCM
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}
#endif
}

View File

@ -1,7 +1,6 @@
#pragma once
#include <torch/all.h>
// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
// 8-bit per-token-group quantization helper used by both FP8 and INT8
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,