@@ -1,14 +1,10 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#ifndef USE_ROCM
  #include "../per_token_group_quant_8bit.h"
#endif

#include <cmath>

#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#include "../../../dispatch_utils.h"
#include "../../vectorization_utils.cuh"

#ifndef USE_ROCM
  #include <cub/cub.cuh>
@@ -25,19 +21,9 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
  static constexpr auto i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // saturate

  // See https://github.com/pytorch/pytorch/issues/127666
  // See https://github.com/llvm/llvm-project/issues/95183
  // hip-clang std::clamp __glibcxx_assert_fail host function when building on
  // Arch/gcc14. The following replaces std::clamp usage with similar logic
  // dst = std::clamp(dst, i8_min, i8_max);
  // Replace std::clamp due to hip-clang issues
  dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
  return static_cast<int8_t>(dst);
#else
@@ -50,26 +36,16 @@ static inline __device__ int8_t float_to_int8_rn(float x) {

static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  // int32_max is not exactly representable as float.
  // Therefore, we need to be careful and manually return int32_max on overflow.
  // For symmetry, we also do the same for int32_min, even though it is exactly
  // representable as float and the conversion should be exact.
  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
  static constexpr auto i32_min_f = static_cast<float>(i32_min);
  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
  static constexpr auto i32_max_f = static_cast<float>(i32_max);

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // saturate on the higher end.
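  // Note: i32_max_f is 2^31 (2147483648.0f), one above i32_max, so the >=
  // check below catches every rounded value that would overflow the int32
  // conversion.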
  if (dst >= i32_max_f) {
    return i32_max;
  }
  // saturate on the lower end.
  if (dst <= i32_min_f) {
    return i32_min;
  }
@@ -90,13 +66,7 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
  static constexpr auto i8_max =
      static_cast<int32_t>(std::numeric_limits<int8_t>::max());

  // saturate

  // See https://github.com/pytorch/pytorch/issues/127666
  // See https://github.com/llvm/llvm-project/issues/95183
  // hip-clang std::clamp __glibcxx_assert_fail host function when building on
  // Arch/gcc14. The following replaces std::clamp usage with similar logic
  // int32_t dst = std::clamp(x, i8_min, i8_max);
  // Replace std::clamp due to hip-clang issues
  int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
  return static_cast<int8_t>(dst);
#else
@@ -118,7 +88,6 @@ __global__ void static_scaled_int8_quant_kernel(
  const int64_t token_idx = blockIdx.x;
  const float scale = *scale_ptr;
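  // Each thread block quantizes one token's row, with every thread applying
  // the same precomputed scale.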

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;
@@ -140,7 +109,6 @@ __global__ void static_scaled_int8_azp_quant_kernel(
  const azp_t azp = *azp_ptr;
  const float inv_s = 1.0f / scale;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;
@@ -160,11 +128,9 @@ __global__ void dynamic_scaled_int8_quant_kernel(
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  // calculate for absmax
  float thread_max = 0.f;
  vectorize_read_with_alignment<16>(
      row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
@@ -183,7 +149,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(

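  // Per-token dynamic scale: multiplying by inv_s maps the row's absolute
  // maximum to 127; an all-zero row gets inv_s = 0 so we never divide by zero.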
  float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;

  // 2. quantize
  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -201,14 +166,12 @@ struct MinMax {

  __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}

  // add a value to the MinMax
  __host__ __device__ MinMax& operator+=(float v) {
    min = fminf(min, v);
    max = fmaxf(max, v);
    return *this;
  }

  // merge two MinMax objects
  __host__ __device__ MinMax& operator&=(const MinMax& other) {
    min = fminf(min, other.min);
    max = fmaxf(max, other.max);
@@ -231,11 +194,9 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  // 1. calculate min & max
  MinMax thread_mm;
  vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
                                    [&] __device__(const scalar_t& src) {
@@ -257,7 +218,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
  __shared__ azp_t azp_sh;
  if (tid == 0) {
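    // Asymmetric (zero-point) quantization: the scale spreads [min, max] over
    // 256 int8 steps and the zero point shifts it so that min maps to -128
    // and max to roughly 127.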
    float s = (mm.max - mm.min) / 255.f;
    float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
    float zp = nearbyintf(-128.f - mm.min / s);
    scale_sh = s;
    azp_sh = azp_t(zp);
    scale_out[blockIdx.x] = s;
@@ -268,7 +229,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
  const float inv_s = 1.f / scale_sh;
  const azp_t azp = azp_sh;

  // 2. quantize
  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -339,14 +299,4 @@ void dynamic_scaled_int8_quant(
                  hidden_size);
        }
      });
}

#ifndef USE_ROCM
void per_token_group_quant_int8(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double int8_min, double int8_max) {
  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                             int8_min, int8_max);
}
#endif
}