[ROCm] Fixes for GPTQ on ROCm (#2180)
This commit is contained in:
@@ -28,6 +28,7 @@ namespace gptq {
|
||||
// DIVIDE(x, size): expands to ((x) + (size) - 1) / (size).
// With integer operands, x >= 0 and size > 0, this is the quotient x / size
// rounded up instead of truncated (e.g. DIVIDE(10, 4) yields 3).
// NOTE: `size` appears twice in the expansion, so an argument with side
// effects would be evaluated twice.
// NOTE(review): presumably used to count how many `size`-wide chunks are
// needed to cover `x` elements; confirm at the call sites.
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#include <hipblas/hipblas.h>
|
||||
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||
hipblasOperation_t transA,
|
||||
hipblasOperation_t transB,
|
||||
@@ -520,12 +521,21 @@ __global__ void gemm_half_q_half_alt_kernel(
|
||||
zeros_tmp[tmp_k] = zero;
|
||||
}
|
||||
for (int m = 0; m < b_end; m++) {
|
||||
#ifndef USE_ROCM
|
||||
res2 = {};
|
||||
#else
|
||||
res2.x = __half_as_ushort(__float2half(0));
|
||||
res2.y = __half_as_ushort(__float2half(0));
|
||||
#endif
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
|
||||
#ifndef USE_ROCM
|
||||
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
|
||||
#else
|
||||
res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
|
||||
#endif
|
||||
}
|
||||
i += width;
|
||||
k += 4;
|
||||
|
||||
Reference in New Issue
Block a user