Add grouped b2b GEMM (#970)
This commit is contained in:
@ -290,7 +290,6 @@ public:
|
||||
int available_sm_count=-1) {
|
||||
// Determine the number of blocks that would be launched to fill up a single
|
||||
// wave on the GPU with each SM having maximum occupancy.
|
||||
cudaDeviceProp properties;
|
||||
int device_idx;
|
||||
cudaError_t result = cudaGetDevice(&device_idx);
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
@ -114,7 +114,7 @@ struct GemmIdentityThreadblockSwizzle {
|
||||
|
||||
/// Calculates optimal swizzle width
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_log_tile(GemmCoord tiled_shape) const {
|
||||
static int get_log_tile(GemmCoord tiled_shape) {
|
||||
auto n = tiled_shape.n();
|
||||
// Thresholds picked so that it doesn't cause too many no-op CTAs
|
||||
if (N >= 8 && n >= 6)
|
||||
@ -187,7 +187,7 @@ struct GemmHorizontalThreadblockSwizzle {
|
||||
|
||||
/// Calculates optimal swizzle width
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_log_tile(GemmCoord tiled_shape) const {
|
||||
static int get_log_tile(GemmCoord tiled_shape) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -228,7 +228,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
|
||||
|
||||
/// Calculates optimal swizzle width
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_log_tile(GemmCoord tiled_shape) const {
|
||||
static int get_log_tile(GemmCoord tiled_shape) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -284,7 +284,7 @@ struct GemmSplitKIdentityThreadblockSwizzle {
|
||||
|
||||
/// Calculates optimal swizzle width
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_log_tile(GemmCoord tiled_shape) const {
|
||||
static int get_log_tile(GemmCoord tiled_shape) {
|
||||
auto n = tiled_shape.n();
|
||||
// Thresholds picked so that it doesn't cause too many no-op CTAs
|
||||
if (N >= 8 && n >= 6)
|
||||
@ -361,7 +361,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle {
|
||||
|
||||
/// Calculates optimal swizzle width
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_log_tile(GemmCoord tiled_shape) const {
|
||||
static int get_log_tile(GemmCoord tiled_shape) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -412,7 +412,7 @@ struct GemvBatchedStridedThreadblockDefaultSwizzle {
|
||||
|
||||
/// Calculates optimal swizzle width
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_log_tile(GemmCoord tiled_shape) const {
|
||||
static int get_log_tile(GemmCoord tiled_shape) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user