@@ -13,14 +13,34 @@
 #include "dispatch_utils.h"
 #include "quantization/fp8/common.cuh"
 
-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-  #define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+    (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__GFX9__
 #endif
 
-#if defined(__HIPCC__) && defined(__gfx942__)
-  #define __HIP__MI300__
+#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__MI3XX__
 #endif
 
+#if defined(__gfx950__)
+  #define LDS_SIZE 160 * 1024
+#else
+  #define LDS_SIZE 64 * 1024
+#endif
+
+int get_lds_size() {
+  static bool is_cached = false;
+  static int result;
+  if (is_cached == false) {
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    std::string device_arch = dprops->gcnArchName;
+    size_t substring = device_arch.find("gfx95");
+    result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
+    is_cached = true;
+  }
+  return result;
+}
+
 #if defined(NDEBUG)
   #undef NDEBUG
   #include <assert.h>
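To make the new pairing concrete, a host-side sketch (not part of the diff) of how the runtime probe is meant to agree with the compile-time `LDS_SIZE` macro; the element-count conversions anticipate the `/ 2` used by the 16-bit kernels below:

```cpp
int lds_bytes = get_lds_size();   // 64 * 1024, or 160 * 1024 when gcnArchName
                                  // contains "gfx95"
int elems_16bit = lds_bytes / 2;  // LDS capacity in half/bfloat16 elements
int elems_fp8 = lds_bytes;        // LDS capacity in fp8 elements (1 byte each)
```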
@@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
       V0 += (s.x + s.y); \
     }
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) // TODO: Add NAVI support
 // This version targets cases where A[] fits LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
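For orientation, the rename lands on all three variants of this kernel; per the comments each carries in this file:

```cpp
// wvSplitK_hf_sml_ : A[] fits LDS capacity outright
// wvSplitK_hf_     : A[] marginally exceeds LDS capacity (<= ~1.2x, per dispatch)
// wvSplitK_hf_big_ : A[] much larger than LDS capacity (windowed staging)
```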
@@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-  #if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+  #if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
   #else
   constexpr bool use_mfma = false;
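Worth spelling out: `max_lds_len` is an element count. `LDS_SIZE` is in bytes and `scalar_t` here is a 16-bit type, hence the `/ 2`; it also recovers the old hard-coded buffer size. A minimal check of that assumption:

```cpp
#include <hip/hip_bf16.h>
static_assert(sizeof(__hip_bfloat16) == 2, "LDS_SIZE / 2 assumes 2-byte scalars");
// 64 * 1024 / 2  == 32768 == the old s[1024 * 32] element count
// 160 * 1024 / 2 == 81920 elements on gfx950
```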
@@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };
 
   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Fetch the activation matrix to LDS
@@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
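On the copy loop: each thread moves one `A_CHUNK`-wide vector per pass, so the stride is `THRDS * WvPrGrp * A_CHUNK`. With the values the WVSPLITK launch below fixes (THRDS = 64, A_CHUNK = 8) and a made-up WvPrGrp of 16 for illustration:

```cpp
constexpr int THRDS = 64, A_CHUNK = 8;  // fixed by the dispatch macro below
constexpr int WvPrGrp = 16;             // hypothetical waves per work group
constexpr int elems_per_pass = THRDS * WvPrGrp * A_CHUNK;  // 8192 elements
// Thread (x, y) copies the slice at k + (y * THRDS + x) * A_CHUNK, and
// min(K * N, max_lds_len) clips the copy to the smaller of A[] and s[].
```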
@@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
@@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) // TODO: Add NAVI support
 // This version targets cases where A[] marginally exceeds LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
@@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
-  #if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+  #if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
   #else
   constexpr bool use_mfma = false;
@@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // - Then the WG will move to another 8 K elements
   // TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       // Fetch A activation matrix in interleaved fashion from LDS or memory
 
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 32 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
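The read path this hunk retargets is then a plain bounds test against the same `max_lds_len`: staged elements come from LDS, the spill tail streams from global memory. The if/else collapses to:

```cpp
// Equivalent ternary for the diff's if/else (same types, same indices):
bigA[n][k2] = (k_ + K * n < max_lds_len)
                  ? *((const bigType*)(&(s[k_ + K * n])))   // staged in LDS
                  : *((const bigType*)(&(A[k_ + K * n])));  // spill: global read
```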
@@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   }
 }
 
-#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
@@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                              const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) // TODO: Add NAVI support
 // This version targets big A[] cases, where it is much larger than LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
@@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-  #if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+  #if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
   #else
   constexpr bool use_mfma = false;
@@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };
 
   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   // then this is not going to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
 #define PCML
 #ifndef PCML
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
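A note on `#define PCML` here: since it is defined unconditionally, the `#ifndef PCML` prefetch in this hunk is compiled out in the big-A variant. My reading (an assumption; the window-advance code is not in this diff) is that A[] is instead staged piecemeal in LDS-sized windows, which is what the next hunk's `kBase`/`kFit` bookkeeping sets up:

```cpp
// PCML staging model, as far as this diff shows it:
//   kFit  = max_lds_len / N;  // K-extent of one LDS-resident window of A
//   kBase                     // first K index of the currently resident window
// The one-shot copy of min(K * N, max_lds_len) elements is dead code here.
```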
@@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
 #define TUC (THRDS * UNRL * A_CHUNK)
   uint32_t kBase = 0;
   // find biggest k size that fits in LDS
-  uint32_t kFit = (32 * 1024) / N;
+  uint32_t kFit = (max_lds_len) / N;
   // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
   // of TUC
   kFit = (kFit % TUC == 0)
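The rounding expression is truncated in this hunk; going by the adjacent comment it aligns `kFit` up to a multiple of TUC. A worked example with the constants visible here (THRDS = 64, A_CHUNK = 8; UNRL = 2 and N = 3 are made-up values for illustration):

```cpp
#include <cstdint>
constexpr int THRDS = 64, UNRL = 2, A_CHUNK = 8;  // UNRL = 2 is assumed here
constexpr int TUC = THRDS * UNRL * A_CHUNK;       // 1024
uint32_t max_lds_len = 32 * 1024;                 // 64 KiB part, 2-byte scalars
uint32_t N = 3, kFit = max_lds_len / N;           // 10922, not TUC-aligned
kFit = (kFit % TUC == 0) ? kFit : (kFit - kFit % TUC + TUC);  // -> 11264
```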
@@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     }
   }
 }
-#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
@@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
 
 int mindiv(int N, int div1, int div2) {
   int nPrRnd = div1 * div2;
@@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size() / 2;
 
 #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                  _N)                                                          \
   {                                                                           \
     dim3 block(64, _WvPrGrp);                                                 \
-    if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) {                \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {              \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);              \
       wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N>          \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,    \
                                        CuCount);                              \
-    } else if (K_in * N_in <= 32 * 1024 * 1.2) {                              \
+    } else if (K_in * N_in <= max_lds_len * 1.2) {                            \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp);              \
       wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N>              \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,    \
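Unrolled from the macro, the selection this hunk rewires now keys off the runtime LDS size rather than a hard-coded 32K elements; the final fall-through to the big variant is my assumption, as that branch is outside the hunk:

```cpp
const int max_lds_len = get_lds_size() / 2;  // element budget, 16-bit scalars
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {
  // wvSplitK_hf_sml_ : all of A[] staged in LDS
} else if (K_in * N_in <= max_lds_len * 1.2) {
  // wvSplitK_hf_     : A[] spills past LDS by at most ~20%
} else {
  // wvSplitK_hf_big_ : A[] streamed through LDS windows (assumed fall-through)
}
```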
@@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   return out_c;
 }
 
-#if defined(__HIP__MI300__) // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                       const float* __restrict__ s_A,
                       const float* __restrict__ s_B, const int _WvPrGrp,
                       const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
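Unlike the 16-bit kernels above, the fp8 kernels take `LDS_SIZE` un-halved: at one byte per fp8 element, byte capacity and element capacity coincide, which is also why the old literal here was `1024 * 64` rather than `1024 * 32`. A sketch of the accounting (LDS_SIZE value shown for gfx942):

```cpp
// fp8: sizeof(fp8_t) == 1, so elements == bytes.
constexpr int kLdsBytes = 64 * 1024;           // LDS_SIZE on gfx90a/gfx942
constexpr int max_lds_len = kLdsBytes;         // s[] slots for fp8_t
constexpr int max_lds_len_16 = kLdsBytes / 2;  // contrast: half/bf16 kernels
```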
@@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };
 
-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];
 
   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
@@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
                                   const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300__) // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                   const fp8_t* __restrict__ A, scalar_t* C,
                   const float* __restrict__ s_A, const float* __restrict__ s_B,
                   const int _WvPrGrp, const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };
 
-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];
 
   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       uint32_t k_ = k + threadIdx.x * A_CHUNK;
       if (k_ >= K) break;
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 64 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
@@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
                              const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
 
 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b,
@@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
   dim3 grid(CuCount);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size();
 
 #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                   _N)                                                          \
   {                                                                            \
     dim3 block(64, _WvPrGrp);                                                  \
-    if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) {                 \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {               \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);               \
       wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N>  \
           <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
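Host side, `wvSplitKQ` follows the same convention as its kernels: `get_lds_size()` is taken un-halved, so the `K_in * N_in <= max_lds_len` test compares fp8 element counts against the full LDS byte budget. For illustration:

```cpp
const int max_lds_len = get_lds_size();  // 65536 (gfx90a/gfx942), 163840 (gfx950)
// gfx942: wvSplitKQ_hf_sml_ is chosen while K_in * N_in <= 65536
// gfx950: the same fully-LDS-resident path covers 2.5x larger activations
```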