Add residual support for shmem staging iterator used in back-to-back GEMM fusion. This allows support of problem_size_0_n that is not multiple of 32. (#590)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -341,7 +341,7 @@ struct B2bGemm {
|
||||
OutputOp0 output_op_0(params.output_op_0);
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
|
||||
|
||||
typename B2bMma::FragmentC0 src_accum;
|
||||
typename B2bMma::FragmentC1 accumulators;
|
||||
|
||||
@ -267,7 +267,9 @@ public:
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
int lane_idx,
|
||||
///< GEMM0 N is used for accumulator extent
|
||||
int problem_size_0_n
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
@ -639,7 +641,6 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
@ -657,12 +658,11 @@ public:
|
||||
tb_frag_A1_bias.clear();
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
++iterator_A1_bias;
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -750,9 +750,9 @@ public:
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
@ -276,13 +276,15 @@ public:
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
int lane_idx,
|
||||
///< GEMM0 N is used for accumulator extent
|
||||
int problem_size_0_n
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx ),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
|
||||
@ -228,7 +228,8 @@ public:
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
int lane_idx, ///< ID of each thread within a warp
|
||||
int problem_size_0_n ///< GEMM0 N is used for accumulator extent
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
|
||||
@ -236,13 +236,14 @@ public:
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
int lane_idx, ///< ID of each thread within a warp
|
||||
int problem_size_0_n ///< GEMM0 N is used for accumulator extent
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
|
||||
@ -43,7 +43,7 @@
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_pipelined_smem_accumulator.h"
|
||||
#include "threadblock/b2b_mma_multistage_smem_accumulator.h"
|
||||
@ -158,11 +158,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
|
||||
@ -303,11 +303,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
|
||||
@ -436,11 +436,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
|
||||
@ -574,11 +574,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true >;
|
||||
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
|
||||
Reference in New Issue
Block a user