added support of b2b bmm (#849)
* added support of b2b bmm * fixed arguments and params structures * added batch_count argument * removed SplitKSerial and added new test case with b2b bmm * fixed support of Kbatched and added new test case with batch stride * added batch support for bias and scale * make test * small changes --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
d572cc1aab
commit
4a68cf748e
@ -49,10 +49,9 @@ namespace kernel {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct B2bGemm {
|
||||
|
||||
@ -61,7 +60,17 @@ struct B2bGemm {
|
||||
using OutputOp0 = typename B2bMma::OutputOp;
|
||||
using OutputOp1 = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
using ElementA0 = typename B2bMma::IteratorA0::Element;
|
||||
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
|
||||
using ElementB0 = typename B2bMma::IteratorB0::Element;
|
||||
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
|
||||
using ElementB1 = typename B2bMma::IteratorB1::Element;
|
||||
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount0 = typename B2bMma::WarpCount0;
|
||||
@ -69,6 +78,7 @@ struct B2bGemm {
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmUniversalMode mode;
|
||||
cutlass::gemm::GemmCoord problem_size_0;
|
||||
cutlass::gemm::GemmCoord problem_size_1;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
@ -89,6 +99,13 @@ struct B2bGemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
int64_t batch_stride_A0;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_D1;
|
||||
int64_t batch_stride_Bias0;
|
||||
int64_t batch_stride_Scale0;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations_0;
|
||||
int gemm_k_size_0;
|
||||
@ -100,11 +117,12 @@ struct B2bGemm {
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmUniversalMode mode,
|
||||
cutlass::gemm::GemmCoord const & problem_size_0,
|
||||
cutlass::gemm::GemmCoord const & problem_size_1,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
@ -116,10 +134,18 @@ struct B2bGemm {
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
||||
int64_t batch_stride_A0,
|
||||
int64_t batch_stride_B0,
|
||||
int64_t batch_stride_B1,
|
||||
int64_t batch_stride_C1,
|
||||
int64_t batch_stride_D1,
|
||||
int64_t batch_stride_Bias0,
|
||||
int64_t batch_stride_Scale0,
|
||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size_0(problem_size_0),
|
||||
problem_size_1(problem_size_1),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
@ -138,6 +164,13 @@ struct B2bGemm {
|
||||
ref_C1(ref_C1),
|
||||
params_D1(ref_D1.layout()),
|
||||
ref_D1(ref_D1),
|
||||
batch_stride_A0(batch_stride_A0),
|
||||
batch_stride_B0(batch_stride_B0),
|
||||
batch_stride_B1(batch_stride_B1),
|
||||
batch_stride_C1(batch_stride_C1),
|
||||
batch_stride_D1(batch_stride_D1),
|
||||
batch_stride_Bias0(batch_stride_Bias0),
|
||||
batch_stride_Scale0(batch_stride_Scale0),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1) {
|
||||
|
||||
@ -163,7 +196,7 @@ struct B2bGemm {
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
B2bGemm() { }
|
||||
B2bGemm() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
@ -223,7 +256,7 @@ struct B2bGemm {
|
||||
|
||||
if(problem_size_0.n() > B2bMma::Shape0::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
|
||||
if(problem_size_1.n() > B2bMma::Shape1::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
@ -247,37 +280,64 @@ struct B2bGemm {
|
||||
return;
|
||||
}
|
||||
|
||||
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
|
||||
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
|
||||
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
|
||||
|
||||
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
|
||||
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
|
||||
|
||||
int offset_k_0 = 0;
|
||||
int offset_k_1 = 0;
|
||||
|
||||
int problem_size_k_0 = params.problem_size_0.k();
|
||||
int problem_size_k_1 = params.problem_size_1.k();
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_0 = min(
|
||||
problem_size_k_0,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_1 = min(
|
||||
problem_size_k_1,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
|
||||
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
|
||||
}
|
||||
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
|
||||
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
||||
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
||||
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
|
||||
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A0{
|
||||
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B0{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B1{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_1,
|
||||
offset_k_1,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_0 = min(
|
||||
params.problem_size_0.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_1 = min(
|
||||
params.problem_size_1.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
|
||||
|
||||
// Compute position within threadblock
|
||||
@ -286,26 +346,25 @@ struct B2bGemm {
|
||||
// Construct iterators to A and B operands
|
||||
typename B2bMma::IteratorA0 iterator_A0(
|
||||
params.params_A0,
|
||||
params.ref_A0.data(),
|
||||
ptr_A0,
|
||||
{params.problem_size_0.m(), problem_size_k_0},
|
||||
thread_idx,
|
||||
tb_offset_A0);
|
||||
|
||||
typename B2bMma::IteratorB0 iterator_B0(
|
||||
params.params_B0,
|
||||
params.ref_B0.data(),
|
||||
ptr_B0,
|
||||
{problem_size_k_0, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
tb_offset_B0);
|
||||
|
||||
typename B2bMma::IteratorB1 iterator_B1(
|
||||
params.params_B1,
|
||||
params.ref_B1.data(),
|
||||
ptr_B1,
|
||||
{problem_size_k_1, params.problem_size_1.n()},
|
||||
thread_idx,
|
||||
tb_offset_B1);
|
||||
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
@ -313,7 +372,7 @@ struct B2bGemm {
|
||||
|
||||
// Construct iterators to accumulator scale/bias vector
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
||||
params.ref_Scale0.data(),
|
||||
ptr_Scale0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -323,7 +382,7 @@ struct B2bGemm {
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
||||
params.ref_Bias0.data(),
|
||||
ptr_Bias0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -349,11 +408,9 @@ struct B2bGemm {
|
||||
src_accum.clear();
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
}
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
@ -376,23 +433,32 @@ struct B2bGemm {
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
||||
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
if (params.grid_tiled_shape.k() > 1) {
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
|
||||
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
params.ref_C1.data(),
|
||||
ptr_C1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
@ -401,21 +467,21 @@ struct B2bGemm {
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D1(
|
||||
params.params_D1,
|
||||
params.ref_D1.data(),
|
||||
ptr_D1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C1 = iterator_D1;
|
||||
@ -427,14 +493,14 @@ struct B2bGemm {
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
@ -457,4 +523,3 @@ struct B2bGemm {
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -114,8 +114,6 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Stage accumulator in shared memory
|
||||
@ -161,22 +159,19 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -188,7 +183,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -228,8 +223,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -249,7 +242,6 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator
|
||||
> {
|
||||
|
||||
@ -274,7 +266,7 @@ struct DefaultB2bGemm<
|
||||
Operator,
|
||||
EpilogueOutputOp0
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -287,7 +279,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -323,20 +315,16 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator> {
|
||||
ThreadblockSwizzle, Stages, Operator> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -360,7 +348,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -396,19 +384,16 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
|
||||
ThreadblockSwizzle, 2, Operator> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -418,7 +403,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -430,7 +415,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -112,22 +112,19 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, true> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
@ -179,8 +175,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator,
|
||||
true
|
||||
> {
|
||||
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
|
||||
false,
|
||||
true
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -277,20 +270,17 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator, true> {
|
||||
Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -350,19 +340,16 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
|
||||
ThreadblockSwizzle, 2, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user