[Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by adapting gptq_gemm (#26092)
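Editor's note on the format change (a sketch, not part of the commit): a GPTQv1 checkpoint stores each zero point minus one, so the existing kernels add 1 back during dequantization; GPTQv2 stores the zero point as-is, which is what makes low-bit and asymmetric checkpoints practical, as the title suggests. The patch threads a use_v2_format flag through every kernel and turns the hard-coded "+ 1" into a zero_offset of 1 (v1) or 0 (v2). A minimal Python reference of the per-weight arithmetic, with illustrative names only:

def dequantize_one(q: int, stored_zero: int, scale: float, use_v2_format: bool) -> float:
    # GPTQv1 checkpoints store (zero - 1); GPTQv2 stores the zero point directly.
    zero_offset = 0 if use_v2_format else 1
    return (q - (stored_zero + zero_offset)) * scale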
@@ -307,7 +307,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
-                        bool use_exllama, int64_t bit);
+                        bool use_exllama, bool use_v2_format, int64_t bit);
 
 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
 
@@ -185,7 +185,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*,
                                                 const uint32_t*, const half*,
                                                 half*, const int, const int,
                                                 const int, const int,
-                                                const int*);
+                                                const bool, const int*);
 
 template <bool first_block, int m_count>
 __global__ void gemm_half_q_half_gptq_4bit_kernel(
@@ -193,12 +193,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, half* __restrict__ c,
     const int size_m, const int size_n, const int size_k, const int groups,
-    const int* __restrict__ b_q_perm) {
+    const bool use_v2_format, const int* __restrict__ b_q_perm) {
   MatrixView_half a_(a, size_m, size_k);
   MatrixView_half_rw c_(c, size_m, size_n);
   MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto t = threadIdx.x;
 
   // Block
@@ -256,10 +259,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
   half2 y1y16[4][2];
   b_gptq_qzeros_.item4(zeros, group, n);
   b_gptq_scales_.item4_f(scales, group, n);
-  dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-  dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-  dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-  dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+  dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
+  dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
+  dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
+  dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
 
   // Column result
   float block_c[m_count][4] = {};
@@ -272,10 +275,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
       nextgroup += groupsize;
       b_gptq_qzeros_.item4(zeros, group, n);
       b_gptq_scales_.item4_f(scales, group, n);
-      dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-      dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-      dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-      dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+      dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
+      dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
+      dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
+      dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
     }
 
 #pragma unroll
@@ -329,12 +332,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, half* __restrict__ c,
     const int size_m, const int size_n, const int size_k, const int groups,
-    const int* __restrict__ b_q_perm) {
+    const bool use_v2_format, const int* __restrict__ b_q_perm) {
   MatrixView_half a_(a, size_m, size_k);
   MatrixView_half_rw c_(c, size_m, size_n);
   MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto t = threadIdx.x;
 
   // Block
@@ -409,10 +415,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
     int4 load_int4 = *b_ptr4;
 
     half2 dq[4][8];
-    dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
-    dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
-    dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
-    dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
+    dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
+    dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
+    dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
+    dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
 
 #pragma unroll
     for (int m = 0; m < m_count; m++) {
@@ -448,12 +454,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, half* __restrict__ c,
     const int size_m, const int size_n, const int size_k, const int groups,
-    const int* __restrict__ b_q_perm) {
+    const bool use_v2_format, const int* __restrict__ b_q_perm) {
   MatrixView_half a_(a, size_m, size_k);
   MatrixView_half_rw c_(c, size_m, size_n);
   MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto t = threadIdx.x;
 
   // Block
@@ -534,13 +543,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
 
     half2 dq[4][16];
     dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
-                    size_n, zeros[0] + 1);
+                    size_n, zeros[0] + zero_offset);
     dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
-                    size_n, zeros[1] + 1);
+                    size_n, zeros[1] + zero_offset);
     dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
-                    size_n, zeros[2] + 1);
+                    size_n, zeros[2] + zero_offset);
     dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
-                    size_n, zeros[3] + 1);
+                    size_n, zeros[3] + zero_offset);
 
 #pragma unroll
     for (int m = 0; m < m_count; m++) {
@@ -574,12 +583,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, half* __restrict__ c,
     const int size_m, const int size_n, const int size_k, const int groups,
-    const int* __restrict__ b_q_perm) {
+    const bool use_v2_format, const int* __restrict__ b_q_perm) {
   MatrixView_half a_(a, size_m, size_k);
   MatrixView_half_rw c_(c, size_m, size_n);
   MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto t = threadIdx.x;
 
   // Block
@@ -658,13 +670,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
 
     half2 dq[4][4];
     dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
-                   zeros[0] + 1);
+                   zeros[0] + zero_offset);
    dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
-                   zeros[1] + 1);
+                   zeros[1] + zero_offset);
    dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
-                   zeros[2] + 1);
+                   zeros[2] + zero_offset);
    dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
-                   zeros[3] + 1);
+                   zeros[3] + zero_offset);
 
     for (int m = 0; m < m_count; m++) {
       block_c[m][0] =
@@ -730,7 +742,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
                                 const uint32_t* b_gptq_qzeros,
                                 const half* b_gptq_scales, const int* b_q_perm,
                                 half* c, int size_m, int size_n, int size_k,
-                                int m_count, int groups, int bit) {
+                                int m_count, int groups, bool use_v2_format,
+                                int bit) {
   dim3 blockDim, gridDim;
   blockDim.x = BLOCK_KN_SIZE;
   blockDim.y = 1;
@@ -743,20 +756,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
       pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
 
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  kernel<<<gridDim, blockDim, 0, stream>>>(a, b_q_weight, b_gptq_qzeros,
-                                           b_gptq_scales, c, size_m, size_n,
-                                           size_k, groups, b_q_perm);
+  kernel<<<gridDim, blockDim, 0, stream>>>(
+      a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k,
+      groups, use_v2_format, b_q_perm);
 }
 
 __global__ void reconstruct_exllama_8bit_kernel(
     const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
-    const int groups, half* __restrict__ b) {
+    const int groups, const bool use_v2_format, half* __restrict__ b) {
   MatrixView_half_rw b_(b, size_k, size_n);
   MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
   auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
 
@@ -812,13 +828,13 @@ __global__ void reconstruct_exllama_8bit_kernel(
 
       half2 dq[4][4];
       dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
-                     zeros[0] + 1);
+                     zeros[0] + zero_offset);
       dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
-                     zeros[1] + 1);
+                     zeros[1] + zero_offset);
       dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
-                     zeros[2] + 1);
+                     zeros[2] + zero_offset);
       dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
-                     zeros[3] + 1);
+                     zeros[3] + zero_offset);
 
       // half* dqh = (half*)dq;
       if (b_q_perm) {
@@ -849,11 +865,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
     const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
-    const int groups, half* __restrict__ b) {
+    const int groups, const bool use_v2_format, half* __restrict__ b) {
   MatrixView_half_rw b_(b, size_k, size_n);
   MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
   auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
 
@@ -888,10 +907,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
   half2 y1y16[4][2];
   b_gptq_qzeros_.item4(zeros, group, n);
   b_gptq_scales_.item4_h2(scales, group, n);
-  dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-  dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-  dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-  dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+  dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
+  dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
+  dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
+  dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
 
   __syncthreads();
 
@@ -904,10 +923,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
       nextgroup += groupsize;
      b_gptq_qzeros_.item4(zeros, group, n);
      b_gptq_scales_.item4_h2(scales, group, n);
-      dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-      dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-      dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-      dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+      dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
+      dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
+      dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
+      dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
     }
 
     for (int p = 0; p < 4; p++) {
@@ -954,11 +973,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
     const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
-    const int groups, half* __restrict__ b) {
+    const int groups, const bool use_v2_format, half* __restrict__ b) {
   MatrixView_half_rw b_(b, size_k, size_n);
   MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
   auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
 
@@ -1016,13 +1038,13 @@ __global__ void reconstruct_exllama_3bit_kernel(
 
       half2 dq[4][16];
      dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
-                      size_n, zeros[0] + 1);
+                      size_n, zeros[0] + zero_offset);
      dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
-                      size_n, zeros[1] + 1);
+                      size_n, zeros[1] + zero_offset);
      dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
-                      size_n, zeros[2] + 1);
+                      size_n, zeros[2] + zero_offset);
      dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
-                      size_n, zeros[3] + 1);
+                      size_n, zeros[3] + zero_offset);
 
       if (b_q_perm) {
         for (int j = 0; j < 16; j++) {
@@ -1052,11 +1074,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
     const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
     const uint32_t* __restrict__ b_gptq_qzeros,
     const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
-    const int groups, half* __restrict__ b) {
+    const int groups, const bool use_v2_format, half* __restrict__ b) {
   MatrixView_half_rw b_(b, size_k, size_n);
   MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
   auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
 
@@ -1108,10 +1133,10 @@ __global__ void reconstruct_exllama_2bit_kernel(
       int4 load_int4 = *b_ptr4;
 
       half2 dq[4][8];
-      dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
-      dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
-      dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
-      dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
+      dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
+      dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
+      dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
+      dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
 
       b_ptr += size_n;
       // half* dqh = (half*)dq;
@@ -1143,7 +1168,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
                          const uint32_t* b_gptq_qzeros,
                          const half* b_gptq_scales, const int* b_q_perm,
                          half* out, int height, int width, int groups,
-                         int bit) {
+                         bool use_v2_format, int bit) {
   dim3 blockDim, gridDim;
   blockDim.x = BLOCK_KN_SIZE;
   blockDim.y = 1;
@@ -1162,14 +1187,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
       b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
-      out);
+      use_v2_format, out);
 }
 
 __global__ void gemm_half_q_half_alt_4bit_kernel(
     const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
     half* __restrict__ mul, const half* __restrict__ scales,
     const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
-    int batch, int height, int width) {
+    int batch, int height, int width, bool use_v2_format) {
   int zero_width = width / 8;
   int vec_height = height * 4;
   const int blockwidth2 = BLOCK_KN_SIZE / 2;
@@ -1179,6 +1204,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
   int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
   auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
   if (threadIdx.x < h_end) {
     for (int m = 0; m < b_end; ++m) {
@@ -1223,10 +1251,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
       half2 zero = __halves2half2(
          __hmul(scale_f,
                 __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) -
-                               1)),
-          __hmul(scale_f2,
-                 __int2half_rn(
-                     -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)));
+                               zero_offset)),
+          __hmul(
+              scale_f2,
+              __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) -
+                            zero_offset)));
       scales_tmp[tmp_k] = scale;
       zeros_tmp[tmp_k] = zero;
     }
@@ -1268,7 +1297,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
     const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
     half* __restrict__ mul, const half* __restrict__ scales,
     const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
-    int batch, int height, int width) {
+    int batch, int height, int width, bool use_v2_format) {
   int zero_width = width / 4;
   int vec_height = height * 2;
   const int blockwidth2 = BLOCK_KN_SIZE / 2;
@@ -1278,6 +1307,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
   int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
   auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
   if (threadIdx.x < h_end) {
     for (int m = 0; m < b_end; ++m) {
@@ -1312,12 +1344,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
       half scale_f2 = scales[g2 * width + w];
       half2 scale = __halves2half2(scale_f, scale_f2);
       half2 zero = __halves2half2(
-          __hmul(scale_f,
-                 __int2half_rn(
-                     -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
-          __hmul(scale_f2,
-                 __int2half_rn(
-                     -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)));
+          __hmul(scale_f, __int2half_rn(
+                              -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) -
+                              zero_offset)),
+          __hmul(
+              scale_f2,
+              __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) -
+                            zero_offset)));
       scales_tmp[tmp_k] = scale;
       zeros_tmp[tmp_k] = zero;
     }
@@ -1355,7 +1388,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
                           const uint32_t* b_gptq_qzeros,
                           const half* b_gptq_scales, const int* b_g_idx,
                           half* c, int size_m, int size_n, int size_k,
-                          int bit) {
+                          bool use_v2_format, int bit) {
   dim3 blockDim, gridDim;
   blockDim.x = BLOCK_KN_SIZE;
   blockDim.y = 1;
@@ -1372,17 +1405,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   kernel<<<gridDim, blockDim, 0, stream>>>(
       (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
-      size_m, size_k / 32 * bit, size_n);
+      size_m, size_k / 32 * bit, size_n, use_v2_format);
 }
 
 template <class T, int bit>
-__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
-                                        const half* __restrict__ w_scales,
-                                        const uint32_t* __restrict__ w_zeros,
-                                        const int* __restrict__ g_idx,
-                                        const int height, const int width,
-                                        const int group,
-                                        half* __restrict__ out) {
+__global__ void reconstruct_gptq_kernel(
+    const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
+    const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
+    const int height, const int width, const int group,
+    const bool use_v2_format, half* __restrict__ out) {
   // Start of block
 
   auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
@@ -1395,6 +1426,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
   MatrixView_half w_scales_(w_scales, group, width);
   T w_zeros_(w_zeros, group, width);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   uint32_t w_read = w[blockIdx.y * width + column];
   half* out_ptr = out_.item_ptr(row, column);
 
@@ -1402,7 +1436,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
   for (int s = 0; s < 32; s += bit) {
     int group = g_idx[row + s / bit];
     half w_scale = w_scales_.item(group, column);
-    uint32_t w_zero = w_zeros_.item(group, column) + 1;
+    uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
    half w_item =
        __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero),
               w_scale);
@@ -1415,7 +1449,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
     const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
     const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
     const int height, const int width, const int group,
-    half* __restrict__ out) {
+    const bool use_v2_format, half* __restrict__ out) {
   // Start of block
   auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
   auto row = blockIdx.y * 32;
@@ -1427,6 +1461,9 @@ __global__ void reconstruct_gptq_3bit_kernel(
   MatrixView_half w_scales_(w_scales, group, width);
   MatrixView_q3_row w_zeros_(w_zeros, group, width);
 
+  // GPTQv2 and GPTQv1 handles zero points differently
+  int zero_offset = use_v2_format ? 0 : 1;
+
   uint32_t w1 = w[(blockIdx.y * 3) * width + column];
   uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
   uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
@@ -1436,7 +1473,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
   for (int i = 0; i < 32; i += 1) {
     int group = g_idx[row + i];
     half w_scale = w_scales_.item(group, column);
-    uint32_t w_zero = w_zeros_.item(group, column) + 1;
+    uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
     int w_item;
     if (i == 10) {
       w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
@@ -1456,7 +1493,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
 
 void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
                       const half* b_gptq_scales, const int* b_g_idx, half* out,
-                      int height, int width, int groups, int bit) {
+                      int height, int width, int groups, bool use_v2_format,
+                      int bit) {
   dim3 blockDim, gridDim;
   blockDim.x = BLOCK_KN_SIZE;
   blockDim.y = 1;
@@ -1476,7 +1514,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
                                            b_gptq_qzeros, b_g_idx, height,
-                                           width, groups, out);
+                                           width, groups, use_v2_format, out);
 }
 
 void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
@@ -1484,7 +1522,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
                            const uint32_t* b_gptq_qzeros,
                            const half* b_gptq_scales, const int* b_g_idx,
                            half* c, half* temp_dq, int size_m, int size_n,
-                           int size_k, int groups, bool use_exllama, int bit) {
+                           int size_k, int groups, bool use_exllama,
+                           bool use_v2_format, int bit) {
   bool use_reconstruct;
   if (use_exllama) {
     use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
@@ -1498,10 +1537,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
     // Reconstruct FP16 matrix, then cuBLAS
     if (use_exllama) {
       reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
-                          temp_dq, size_k, size_n, groups, bit);
+                          temp_dq, size_k, size_n, groups, use_v2_format, bit);
     } else {
       reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
-                       temp_dq, size_k, size_n, groups, bit);
+                       temp_dq, size_k, size_n, groups, use_v2_format, bit);
     }
 
     const half alpha = __float2half(1.0f);
@@ -1517,18 +1556,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
     if (max_chunks) {
       gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, c, last_chunk, size_n, size_k,
-                                 BLOCK_M_SIZE_MAX, groups, bit);
+                                 BLOCK_M_SIZE_MAX, groups, use_v2_format, bit);
     }
 
     if (last_chunk_size) {
-      gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight,
-                                 b_gptq_qzeros, b_gptq_scales, b_g_idx,
-                                 c + last_chunk * size_n, last_chunk_size,
-                                 size_n, size_k, last_chunk_size, groups, bit);
+      gemm_half_q_half_cuda_part(
+          a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales,
+          b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k,
+          last_chunk_size, groups, use_v2_format, bit);
     }
   } else {
     gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
-                         c, size_m, size_n, size_k, bit);
+                         c, size_m, size_n, size_k, use_v2_format, bit);
   }
 }
 
@@ -1815,7 +1854,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
-                        bool use_exllama, int64_t bit) {
+                        bool use_exllama, bool use_v2_format, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
   at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
@@ -1833,7 +1872,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         c.size(1),              // n
                         a.size(1),              // k
                         b_gptq_qzeros.size(0),  // group number
-                        use_exllama, bit);
+                        use_exllama, use_v2_format, bit);
   return c;
 }
 
@@ -557,7 +557,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // to prevent the meta function registry.
   ops.def(
       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
-      "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
+      "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
+      "use_v2_format, int bit) "
      "-> Tensor",
      {stride_tag});
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
@@ -26,4 +26,10 @@ def test_gptq_gemm_opcheck():
     idx = torch.empty((0,), device="cuda", dtype=torch.int32)
     use_exllama = True
     bit = 4
-    opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit))
+    # Test both GPTQv1 and GPTQv2 format
+    opcheck(
+        torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, True, bit)
+    )
+    opcheck(
+        torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, False, bit)
+    )
tests/quantization/test_gptq_v2.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests whether vllm correctly loads and runs gptq_v2 format checkpoints.
+
+Run `pytest tests/quantization/test_gptq_v2.py --forked`.
+"""
+
+import pytest
+import torch
+from transformers import AutoTokenizer
+
+from vllm import SamplingParams
+from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+
+# A dummy small model quantized by GPTQModel, stored in GPTQ v2 format
+MODELS = ["XXXXyu/Qwen3-1.7B-w2g64-gptq_v2"]
+
+# Generate multiple sequences for testing, because a 1.7B 2-bit model
+# cannot always generate normal texts.
+N_SEQ = 5
+
+
+@pytest.mark.parametrize("model_id", MODELS)
+def test_model_load(vllm_runner, model_id, monkeypatch):
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    # Only check the default GPTQ linear method (used for 2/3-bit models).
+    # 4/8-bit linear methods like Marlin already support gptq_v2.
+    linear_method_cls = GPTQLinearMethod
+
+    with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:
+
+        def check_model(model):
+            for name, submodule in model.named_modules():
+                # Could check more modules if necessary
+                if name == "model.layers.0.self_attn.qkv_proj":
+                    assert isinstance(submodule.quant_method, linear_method_cls)
+
+                    config = submodule.quant_method.quant_config
+                    assert config.checkpoint_format == "gptq_v2"
+                    assert submodule.quant_method.use_v2_format
+
+                    # Just break since currently we only check 1 module
+                    break
+
+        # Check if gptq_v2 format is correctly loaded
+        llm.apply_model(check_model)
+
+
+@pytest.mark.parametrize("model_id", MODELS)
+def test_model_inference(vllm_runner, model_id):
+    # Prepare prompt to test the model's generation result.
+    prompt = "What is the meaning of life?"
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": prompt},
+    ]
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        enable_thinking=False,  # If thinking model, set it to false
+    )
+    sampling_params = SamplingParams(
+        n=N_SEQ,
+        max_tokens=128,
+        temperature=0.7,
+        top_p=0.8,
+        top_k=20,
+        min_p=0,
+        presence_penalty=2,
+    )
+
+    with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:
+        # Generate a response to verify inference correctness
+        output = llm.generate(text, sampling_params)
+
+    # Make sure the output exists
+    assert output
+    assert output[0][1]
+    assert len(output[0][1]) == N_SEQ
+
+    def has_normal_char_distribution(texts, min_len):
+        for text in texts:
+            # Response too short
+            if len(text) < min_len:
+                return False
+
+            # Basic ratio checks
+            letters = sum(c.isalpha() for c in text)
+            spaces = sum(c.isspace() for c in text)
+            total = len(text)
+
+            letter_ratio = letters / total
+            space_ratio = spaces / total
+
+            # At least 1 normal text should exist within output sequences
+            # Normal text should be mostly letters with reasonable spacing
+            # Some magic numbers, could be adjusted
+            if 0.5 <= letter_ratio <= 0.9 and 0.01 <= space_ratio <= 0.3:
+                return True
+        # No sequence contains normal text, output might be broken
+        return False
+
+    # Apply some simple checks for gibberish output
+    # Print the output sequences if failed
+    assert has_normal_char_distribution(output[0][1], 5), output[0][1]
@@ -451,10 +451,18 @@ def gptq_gemm(
     b_gptq_scales: torch.Tensor,
     b_g_idx: torch.Tensor,
     use_exllama: bool,
+    use_v2_format: bool,
     bit: int,
 ) -> torch.Tensor:
     return torch.ops._C.gptq_gemm(
-        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit
+        a,
+        b_q_weight,
+        b_gptq_qzeros,
+        b_gptq_scales,
+        b_g_idx,
+        use_exllama,
+        use_v2_format,
+        bit,
     )
 
 
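For orientation, a hedged sketch of calling the updated wrapper above. The shapes follow the usual GPTQ packing of 32/bit values per int32 word and are this editor's assumption, not something stated in the diff; without a prior gptq_shuffle the numbers are meaningless, so this only exercises the new signature (mirroring the opcheck test earlier in the diff).

import torch
from vllm import _custom_ops as ops

m, k, n, bit = 2, 128, 256, 4      # a single quantization group over k (assumed layout)
pack = 32 // bit                   # values packed per int32 word
imin, imax = torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max
x = torch.rand((m, k), device="cuda", dtype=torch.float16)
qweight = torch.randint(imin, imax, (k // pack, n), device="cuda", dtype=torch.int32)
qzeros = torch.randint(imin, imax, (1, n // pack), device="cuda", dtype=torch.int32)
scales = torch.rand((1, n), device="cuda", dtype=torch.float16)
g_idx = torch.empty((0,), device="cuda", dtype=torch.int32)  # no act-order

# use_exllama=True, use_v2_format=False reproduces the old GPTQv1 behavior.
out = ops.gptq_gemm(x, qweight, qzeros, scales, g_idx, True, False, bit)
assert out.shape == (m, n)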
@@ -468,6 +476,7 @@ if hasattr(torch.ops._C, "gptq_gemm"):
         b_gptq_scales: torch.Tensor,
         b_g_idx: torch.Tensor,
         use_exllama: bool,
+        use_v2_format: bool,
         bit: int,
     ) -> torch.Tensor:
         return torch.empty(
@@ -11,6 +11,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
@@ -36,6 +37,8 @@ if TYPE_CHECKING:
 else:
     QuantizationMethods = str
 
+logger = init_logger(__name__)
+
 
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
@@ -52,6 +55,7 @@ class GPTQConfig(QuantizationConfig):
         dynamic: dict[str, dict[str, int | bool]],
         autoround_version: str = "",
         modules_in_block_to_quantize: list[str] | None = None,
+        checkpoint_format: str = "",
     ) -> None:
         # GPTQModel use `dynamic` config property to allow per module
         # quantization config so each module can be individually optimized.
@@ -89,12 +93,24 @@ class GPTQConfig(QuantizationConfig):
                 "Currently, only 2/3/4/8-bit weight quantization is "
                 f"supported for GPTQ, but got {self.weight_bits} bits."
             )
+        # Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future.
+        # For now, show a warning, since gptq_marlin will be used by default.
+        if self.weight_bits == 4:
+            logger.warning_once(
+                "Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. "
+                "Please switch to gptq_marlin or gptq_bitblas."
+            )
+
         self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
 
         # used to identify GPTQ model quantized by autoround
         self.autoround_version = autoround_version
 
+        # GPTQ v1 and v2 format deals with zero points differently.
+        # Currently GPTQModel stores v1 format checkpoints by default,
+        # but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
+        self.checkpoint_format = checkpoint_format
+
     def __repr__(self) -> str:
         return (
             f"GPTQConfig(weight_bits={self.weight_bits}, "
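For context (editor's illustration, not from the commit): a GPTQv2 checkpoint announces itself through the quantization config that GPTQConfig.from_config reads further down; only "checkpoint_format" is consumed by the new code path, while the other keys below are the usual GPTQ fields and their exact set is an assumption of this sketch.

quantization_config = {
    "quant_method": "gptq",
    "bits": 2,
    "group_size": 64,
    "desc_act": False,
    "sym": False,                    # asymmetric quantization
    "checkpoint_format": "gptq_v2",  # missing or "gptq" means the v1 layout
}

GPTQLinearMethod then derives self.use_v2_format by comparing this field against "gptq_v2" (see the hunk below); per the comment above, GPTQModel emits such checkpoints when its QuantizeConfig is given format="gptq_v2".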
@@ -102,7 +118,8 @@ class GPTQConfig(QuantizationConfig):
             f"desc_act={self.desc_act}), "
             f"lm_head_quantized={self.lm_head_quantized}, "
             f"dynamic={self.dynamic}, "
-            f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
+            f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), "
+            f"checkpoint_format={self.checkpoint_format})"
         )
 
     @classmethod
@@ -137,6 +154,9 @@ class GPTQConfig(QuantizationConfig):
         modules_in_block_to_quantize = cls.get_from_keys_or(
             config, ["modules_in_block_to_quantize"], default=None
         )
+        checkpoint_format = cls.get_from_keys_or(
+            config, ["checkpoint_format"], default=""
+        )
         return cls(
             weight_bits,
             group_size,
@@ -145,6 +165,7 @@ class GPTQConfig(QuantizationConfig):
             dynamic,
             autoround_version,
             modules_in_block_to_quantize,
+            checkpoint_format,
         )
 
     def get_quant_method(
@@ -154,6 +175,7 @@ class GPTQConfig(QuantizationConfig):
             # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
             from .moe_wna16 import MoeWNA16Config
 
+            # TODO: maybe update this for GPTQv2 format checkpoints
             config = {
                 "quant_method": "gptq",
                 "bits": self.weight_bits,
@@ -210,6 +232,9 @@ class GPTQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: GPTQConfig):
         self.quant_config = quant_config
 
+        # GPTQ v1 and v2 format deals with zero points differently
+        self.use_v2_format = quant_config.checkpoint_format == "gptq_v2"
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -351,6 +376,8 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])
 
+        # GPTQ v1 and v2 format checkpoints deals with zero points differently,
+        # and require different gemm kernels.
         output = ops.gptq_gemm(
             reshaped_x,
             layer.qweight,
@@ -358,6 +385,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.scales,
             layer.g_idx,
             layer.exllama_state == ExllamaState.READY,
+            self.use_v2_format,
             self.quant_config.weight_bits,
         )
         if bias is not None:
@@ -145,10 +145,15 @@ class ExllamaLinearKernel(MPLinearKernel):
 
         w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
 
+        # gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
+        # However, the MPLinearLayerConfig doesn't contain format info.
+        # So hardcode GPTQv1 format here, to keep its behavior unchanged.
+        use_v2_format = False
+
         assert w_zp is not None, "Zero points are required by Exllama"
         assert w_g_idx is not None, "Group index is required by Exllama"
         output = ops.gptq_gemm(
-            x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
+            x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
         )
 
         if bias is not None: