Add grouped b2b GEMM (#970)

This commit is contained in:
Jack Kosaian
2023-06-05 17:16:57 -04:00
committed by GitHub
parent fde824af21
commit 87349d3496
15 changed files with 1644 additions and 107 deletions

View File

@@ -290,7 +290,6 @@ public:
int available_sm_count=-1) {
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
cudaDeviceProp properties;
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {

View File

@@ -114,7 +114,7 @@ struct GemmIdentityThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
-int get_log_tile(GemmCoord tiled_shape) const {
+static int get_log_tile(GemmCoord tiled_shape) {
auto n = tiled_shape.n();
// Thresholds picked so that it doesn't cause too many no-op CTAs
if (N >= 8 && n >= 6)
@@ -187,7 +187,7 @@ struct GemmHorizontalThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
-int get_log_tile(GemmCoord tiled_shape) const {
+static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}
@@ -228,7 +228,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
-int get_log_tile(GemmCoord tiled_shape) const {
+static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}
@@ -284,7 +284,7 @@ struct GemmSplitKIdentityThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
-int get_log_tile(GemmCoord tiled_shape) const {
+static int get_log_tile(GemmCoord tiled_shape) {
auto n = tiled_shape.n();
// Thresholds picked so that it doesn't cause too many no-op CTAs
if (N >= 8 && n >= 6)
@@ -361,7 +361,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
-int get_log_tile(GemmCoord tiled_shape) const {
+static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}
@@ -412,7 +412,7 @@ struct GemvBatchedStridedThreadblockDefaultSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
-int get_log_tile(GemmCoord tiled_shape) const {
+static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}