added support of b2b bmm (#849)

* added support of b2b bmm

* fixed arguments and params structures

* added batch_count argument

* removed SplitKSerial and added new test case with b2b bmm

* fixed support of Kbatched and added new test case with batch stride

* added batch support for bias and scale

* make test

* small changes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Aleksandr Pivovar
2023-04-15 05:20:02 +02:00
committed by GitHub
parent d572cc1aab
commit 4a68cf748e
10 changed files with 765 additions and 409 deletions

View File

@ -49,10 +49,9 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct B2bGemm {
@ -61,7 +60,17 @@ struct B2bGemm {
using OutputOp0 = typename B2bMma::OutputOp;
using OutputOp1 = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA0 = typename B2bMma::IteratorA0::Element;
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
using ElementB0 = typename B2bMma::IteratorB0::Element;
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
using ElementB1 = typename B2bMma::IteratorB1::Element;
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
/// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0;
@ -69,6 +78,7 @@ struct B2bGemm {
/// Parameters structure
struct Params {
cutlass::gemm::GemmUniversalMode mode;
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
@ -89,6 +99,13 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
int *semaphore;
int gemm_k_iterations_0;
int gemm_k_size_0;
@ -100,11 +117,12 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmUniversalMode mode,
cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
@ -116,10 +134,18 @@ struct B2bGemm {
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
int64_t batch_stride_A0,
int64_t batch_stride_B0,
int64_t batch_stride_B1,
int64_t batch_stride_C1,
int64_t batch_stride_D1,
int64_t batch_stride_Bias0,
int64_t batch_stride_Scale0,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
int *workspace = nullptr
):
mode(mode),
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
@ -138,6 +164,13 @@ struct B2bGemm {
ref_C1(ref_C1),
params_D1(ref_D1.layout()),
ref_D1(ref_D1),
batch_stride_A0(batch_stride_A0),
batch_stride_B0(batch_stride_B0),
batch_stride_B1(batch_stride_B1),
batch_stride_C1(batch_stride_C1),
batch_stride_D1(batch_stride_D1),
batch_stride_Bias0(batch_stride_Bias0),
batch_stride_Scale0(batch_stride_Scale0),
output_op_0(output_op_0),
output_op_1(output_op_1) {
@ -163,7 +196,7 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
B2bGemm() { }
B2bGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
@ -223,7 +256,7 @@ struct B2bGemm {
if(problem_size_0.n() > B2bMma::Shape0::kN)
return Status::kErrorInvalidProblem;
if(problem_size_1.n() > B2bMma::Shape1::kN)
return Status::kErrorInvalidProblem;
@ -247,37 +280,64 @@ struct B2bGemm {
return;
}
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
int offset_k_0 = 0;
int offset_k_1 = 0;
int problem_size_k_0 = params.problem_size_0.k();
int problem_size_k_1 = params.problem_size_1.k();
if (params.mode == GemmUniversalMode::kGemm) {
// Problem size is a function of threadblock index in the K dimension
problem_size_k_0 = min(
problem_size_k_0,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Problem size is a function of threadblock index in the K dimension
problem_size_k_1 = min(
problem_size_k_1,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
};
cutlass::MatrixCoord tb_offset_B0{
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
threadblock_tile_offset.n() * B2bMma::Shape0::kN
};
cutlass::MatrixCoord tb_offset_B1{
threadblock_tile_offset.k() * params.gemm_k_size_1,
offset_k_1,
threadblock_tile_offset.n() * B2bMma::Shape1::kN
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_0 = min(
params.problem_size_0.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_1 = min(
params.problem_size_1.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
// Compute threadblock-scoped matrix multiply-add
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// Compute position within threadblock
@ -286,26 +346,25 @@ struct B2bGemm {
// Construct iterators to A and B operands
typename B2bMma::IteratorA0 iterator_A0(
params.params_A0,
params.ref_A0.data(),
ptr_A0,
{params.problem_size_0.m(), problem_size_k_0},
thread_idx,
tb_offset_A0);
typename B2bMma::IteratorB0 iterator_B0(
params.params_B0,
params.ref_B0.data(),
ptr_B0,
{problem_size_k_0, params.problem_size_0.n()},
thread_idx,
tb_offset_B0);
typename B2bMma::IteratorB1 iterator_B1(
params.params_B1,
params.ref_B1.data(),
ptr_B1,
{problem_size_k_1, params.problem_size_1.n()},
thread_idx,
tb_offset_B1);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
@ -313,7 +372,7 @@ struct B2bGemm {
// Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
params.ref_Scale0.data(),
ptr_Scale0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -323,7 +382,7 @@ struct B2bGemm {
);
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
params.ref_Bias0.data(),
ptr_Bias0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -349,11 +408,9 @@ struct B2bGemm {
src_accum.clear();
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
}
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
//
// Epilogue
@ -376,23 +433,32 @@ struct B2bGemm {
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
// Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
params.ref_C1.data(),
ptr_C1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
@ -401,21 +467,21 @@ struct B2bGemm {
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D1(
params.params_D1,
params.ref_D1.data(),
ptr_D1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C1 = iterator_D1;
@ -427,14 +493,14 @@ struct B2bGemm {
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
@ -457,4 +523,3 @@ struct B2bGemm {
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -114,8 +114,6 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Stage accumulator in shared memory
@ -161,22 +159,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -188,7 +183,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -228,8 +223,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -249,7 +242,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator
> {
@ -274,7 +266,7 @@ struct DefaultB2bGemm<
Operator,
EpilogueOutputOp0
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -287,7 +279,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -323,20 +315,16 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator> {
ThreadblockSwizzle, Stages, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -360,7 +348,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -396,19 +384,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
ThreadblockSwizzle, 2, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -418,7 +403,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -430,7 +415,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -112,22 +112,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, true> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture
@ -179,8 +175,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator,
true
> {
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
false,
true
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -277,20 +270,17 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator, true> {
Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -350,19 +340,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
ThreadblockSwizzle, 2, Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////