CUTLASS 2.6 (#298)

CUTLASS 2.6
This commit is contained in:
Manish Gupta
2021-07-22 21:40:53 -07:00
committed by GitHub
parent 6c29fe20ba
commit e5d51840e8
308 changed files with 32408 additions and 4722 deletions

View File

@ -66,6 +66,7 @@ struct B2bGemm {
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename B2bMma::IteratorA0::Params params_A0;
typename B2bMma::IteratorA0::TensorRef ref_A0;
typename B2bMma::IteratorB0::Params params_B0;
@ -91,7 +92,7 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE
@ -112,6 +113,7 @@ struct B2bGemm {
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A0(ref_A0.layout()),
ref_A0(ref_A0),
params_B0(ref_B0.layout()),
@ -211,7 +213,7 @@ struct B2bGemm {
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -315,7 +317,7 @@ struct B2bGemm {
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -209,6 +209,7 @@ struct B2bImplicitGemmConvolution {
cutlass::gemm::GemmCoord grid_tiled_shape;
gemm::GemmCoord implicit_gemm_problem_size_0;
gemm::GemmCoord implicit_gemm_problem_size_1;
int swizzle_log_tile;
int gemm_k_iterations_0;
int gemm_k_iterations_1;
typename B2bMma::IteratorA0::Params iterator_A0;
@ -233,7 +234,7 @@ struct B2bImplicitGemmConvolution {
//
CUTLASS_HOST_DEVICE
Params(): gemm_k_iterations_0(0), gemm_k_iterations_1(0) { }
Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { }
///
CUTLASS_HOST_DEVICE
@ -245,7 +246,6 @@ struct B2bImplicitGemmConvolution {
problem_size_1(args.problem_size_1),
implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)),
implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)),
grid_tiled_shape(grid_tiled_shape),
iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())),
ptr_A0(args.ref_A0.data()),
iterator_B0(args.problem_size_0, args.ref_B0.layout()),
@ -272,6 +272,8 @@ struct B2bImplicitGemmConvolution {
implicit_gemm_problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.problem_size_0.split_k_slices);
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
}
};
@ -296,7 +298,7 @@ struct B2bImplicitGemmConvolution {
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_idx =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
@ -379,7 +381,7 @@ struct B2bImplicitGemmConvolution {
// Compute logical position within grid
threadblock_tile_idx =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// If performing a reduction via split-K, fetch the initial synchronization
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

View File

@ -111,9 +111,7 @@ template <
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Beta is zero or not
bool IsBetaZero = false
typename Operator
>
struct DefaultB2bGemm;
@ -321,9 +319,7 @@ template <
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Is Beta zero or not
bool IsBetaZero>
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
@ -332,7 +328,7 @@ struct DefaultB2bGemm<
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator, IsBetaZero> {
SplitKSerial, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -353,8 +349,7 @@ struct DefaultB2bGemm<
using Epilogue = typename cutlass::epilogue::threadblock::
DefaultInterleavedEpilogueTensorOp<
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
64 / sizeof_bits<ElementC>::value, InterleavedK,
IsBetaZero>::Epilogue;
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
@ -397,9 +392,7 @@ template <
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Is Beta zero or not
bool IsBetaZero>
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
@ -407,7 +400,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -426,8 +419,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
using Epilogue = typename cutlass::epilogue::threadblock::
DefaultInterleavedEpilogueTensorOp<
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
64 / sizeof_bits<ElementC>::value, InterleavedK,
IsBetaZero>::Epilogue;
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;