Add residual support for shmem staging iterator used in back-to-back GEMM fusion. This allows support of problem_size_0_n that is not multiple of 32. (#590)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Haicheng Wu
2022-08-15 11:19:24 -04:00
committed by GitHub
parent e66bfcb1f8
commit 497b499d9d
8 changed files with 462 additions and 35 deletions

View File

@ -341,7 +341,7 @@ struct B2bGemm {
OutputOp0 output_op_0(params.output_op_0);
// Construct thread-scoped matrix multiply
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
typename B2bMma::FragmentC0 src_accum;
typename B2bMma::FragmentC1 accumulators;

View File

@ -267,7 +267,9 @@ public:
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
int lane_idx,
///< GEMM0 N is used for accumulator extent
int problem_size_0_n
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
@ -639,7 +641,6 @@ public:
}
// 2nd Gemm
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
@ -657,12 +658,11 @@ public:
tb_frag_A1_bias.clear();
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;
//
// Prologue
//
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
@ -750,9 +750,9 @@ public:
// Mainloop
//
gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1);
CUTLASS_PRAGMA_UNROLL
for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
//
// Loop over GEMM K dimension
//

View File

@ -276,13 +276,15 @@ public:
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
int lane_idx,
///< GEMM0 N is used for accumulator extent
int problem_size_0_n
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx ),
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to

View File

@ -228,7 +228,8 @@ public:
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
int lane_idx, ///< ID of each thread within a warp
int problem_size_0_n ///< GEMM0 N is used for accumulator extent
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),

View File

@ -236,13 +236,14 @@ public:
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
int lane_idx, ///< ID of each thread within a warp
int problem_size_0_n ///< GEMM0 N is used for accumulator extent
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx),
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to

View File

@ -43,7 +43,7 @@
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
#include "threadblock/b2b_mma_pipelined_smem_accumulator.h"
#include "threadblock/b2b_mma_multistage_smem_accumulator.h"
@ -158,11 +158,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
static int const kThreadCount = 32;
// load warp tile from Shared Memory accumulator
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
ElementA, SmemAccumulatorLayout,
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
@ -303,11 +303,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
static int const kThreadCount = 32;
// load warp tile from Shared Memory accumulator
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
ElementA, SmemAccumulatorLayout,
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
@ -436,11 +436,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
static int const kThreadCount = 32;
// load warp tile from Shared Memory accumulator
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
ElementA, SmemAccumulatorLayout,
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
@ -574,11 +574,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
static int const kThreadCount = 32;
// load warp tile from Shared Memory accumulator
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
ElementA, SmemAccumulatorLayout,
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true >;
// Define the threadblock-scoped multistage matrix multiply