CUTLASS 2.6.1 - functional and performance enhancements to strided DGRAD, fixes, and tuning

* cutlass 2.6 update

* remove debug prints

* cutlass 2.6.1 (minor update)

* Updated CHANGELOG.

* Minor edit to readme to indicate patch version.

* Minor edit to readme.

Co-authored-by:  Haicheng Wu <haichengw@nvidia.com>, Andrew Kerr <akerr@nvidia.com>
This commit is contained in:
Manish Gupta
2021-09-03 10:26:15 -07:00
committed by GitHub
parent a01feb93d9
commit 6c2f8f2fb8
55 changed files with 317 additions and 315 deletions

View File

@ -47,6 +47,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/functional.h"
namespace cutlass {
namespace conv {
@ -485,6 +486,27 @@ int strided_dgrad_tile_m_per_filter(
return tile_m_per_filter;
}
// Computes the starting Dx coordinate (h, w) for a given starting filter position (r, s)
CUTLASS_HOST_DEVICE
void strided_dgrad_starting_coords(
  Conv2dProblemSize const &problem_size,
  FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
  int r, int s,
  int &start_h, int &start_w) {

  // Remainders (pad_h % stride_h) and (pad_w % stride_w) obtained via fast divmod
  int rem_pad_h, rem_pad_w;

  // Equivalent to:
  //   start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
  stride_h_divmod.divmod(rem_pad_h, problem_size.pad_h);
  int shifted_r = std::abs(problem_size.stride_h - (rem_pad_h - r));
  stride_h_divmod.divmod(start_h, shifted_r);

  // Equivalent to:
  //   start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
  stride_w_divmod.divmod(rem_pad_w, problem_size.pad_w);
  int shifted_s = std::abs(problem_size.stride_w - (rem_pad_w - s));
  stride_w_divmod.divmod(start_w, shifted_s);
}
} // namespace conv
} // namespace cutlass

View File

@ -217,14 +217,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
cutlass::Kernel<ImplicitGemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;

View File

@ -199,7 +199,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
struct Params {
ConvProblemSize problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
FastDivmod filter_s_divmod;
FastDivmod stride_h_divmod;
FastDivmod stride_w_divmod;
int gemm_k_iterations;
typename Mma::IteratorA::Params iterator_A;
typename Mma::IteratorA::Element const *ptr_A;
@ -227,7 +228,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
int *semaphore = nullptr
):
problem_size(args.problem_size),
filter_s_divmod(args.problem_size.stride_w),
stride_h_divmod(args.problem_size.stride_h),
stride_w_divmod(args.problem_size.stride_w),
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
ptr_A(args.ref_A.data()),
iterator_B(args.problem_size, args.ref_B.layout()),
@ -297,7 +299,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
// int start_s = filter_tile_m % (params.problem_size.stride_w);
int start_r, start_s;
params.filter_s_divmod(start_r, start_s, filter_tile_m);
params.stride_w_divmod(start_r, start_s, filter_tile_m);
typename Mma::FragmentC accumulators;
@ -320,6 +322,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
params.problem_size,
params.ptr_A,
thread_idx,
params.stride_h_divmod, params.stride_w_divmod,
start_r, start_s,
MatrixCoord(
threadblock_tile_idx.m() * Mma::Shape::kM,
@ -386,6 +389,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
params.ptr_D,
ConvOutputIteratorParameter::extent(params.problem_size),
thread_idx,
params.stride_h_divmod, params.stride_w_divmod,
start_r, start_s,
threadblock_offset
);
@ -396,6 +400,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
params.ptr_C,
ConvOutputIteratorParameter::extent(params.problem_size),
thread_idx,
params.stride_h_divmod, params.stride_w_divmod,
start_r, start_s,
threadblock_offset
);

View File

@ -130,7 +130,6 @@ private:
int offset_p_[ThreadMap::Iterations::kStrided];
int offset_q_[ThreadMap::Iterations::kStrided];
public:
CUTLASS_HOST_DEVICE
@ -139,6 +138,7 @@ public:
Conv2dProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
):
@ -164,9 +164,12 @@ public:
}
// Starting h, w positions for filter position in gemm_k=0
int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h);
int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w);
int start_h, start_w;
strided_dgrad_starting_coords(
problem_size_,
stride_h_divmod, stride_w_divmod,
filter_r, filter_s,
start_h, start_w);
// Effective P and Q for filter position required for remapping NHW rows
int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;

View File

@ -200,7 +200,27 @@ private:
public:
/// Constructor
/// Constructor (output gradient (Dy) OperandA ctor)
/// Forwards all arguments unchanged to the underlying tile access iterator;
/// stride_h_divmod / stride_w_divmod are precomputed fast-divmod objects for
/// the convolution strides, passed through so the access iterator can derive
/// starting (h, w) coordinates without runtime division.
CUTLASS_HOST_DEVICE
TileIteratorStridedDgrad(
Params const &params,
ConvProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset in units of whole CTA tiles
):
tile_access_iterator_(
params,
problem_size,
ptr,
thread_idx,
stride_h_divmod, stride_w_divmod,
start_r, start_s,
threadblock_offset) { }
/// Constructor (filter (w) OperandB ctor)
CUTLASS_HOST_DEVICE
TileIteratorStridedDgrad(
Params const &params,
@ -210,7 +230,12 @@ public:
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord()
):
tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { }
tile_access_iterator_(params,
problem_size,
ptr,
thread_idx,
start_r, start_s,
threadblock_offset) { }
CUTLASS_HOST_DEVICE
static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {

View File

@ -31,6 +31,12 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef CUTLASS_NAMESPACE
#define cutlass CUTLASS_NAMESPACE
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0)
#if defined(_MSC_VER)

View File

@ -174,12 +174,12 @@ public:
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
if (Scale == ScaleType::Nothing)
return destination_converter(converted_accumulator);
ComputeFragment converted_source = source_converter(source);
// Perform binary operations
ComputeFragment intermediate;

View File

@ -309,9 +309,12 @@ struct DefaultEpilogueTensorOp {
kElementsPerAccess
>::Type;
static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
OutputTileThreadMap,
ElementOutput
ElementOutput,
UseCUDAStore
>;
using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,

View File

@ -62,7 +62,8 @@ namespace threadblock {
///
template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_ ///< Element data type
typename Element_, ///< Element data type
bool UseCUDAStore = false
>
class PredicatedTileIterator {
public:
@ -341,10 +342,17 @@ public:
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
guard);
if (UseCUDAStore) {
if (guard) {
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
}
} else {
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
guard);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {

View File

@ -222,6 +222,7 @@ public:
Element *pointer,
TensorCoord extent,
int thread_idx,
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
int start_r, int start_s,
TensorCoord threadblock_offset = TensorCoord()
):
@ -238,9 +239,12 @@ public:
s = (params_.problem_size.S - 1 - s);
}
// check if start_h_ and start_w_ are always positive
start_h_ = std::abs((params_.problem_size.pad_h - r) % params_.problem_size.stride_h);
start_w_ = std::abs((params_.problem_size.pad_w - s) % params_.problem_size.stride_w);
// compute starting coordinates in Dx start_h_ and start_w_
strided_dgrad_starting_coords(
params_.problem_size,
stride_h_divmod, stride_w_divmod,
r, s,
start_h_, start_w_);
p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h;
q_ = (params_.problem_size.W - start_w_ + params_.problem_size.stride_w - 1) / params_.problem_size.stride_w;

View File

@ -256,20 +256,7 @@ public:
int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess;
#if 0
// Using inline PTX to avoid generic memory
AccessType *smem_ptr = pointers_[ptr_idx];
smem_ptr[offset] = frag_ptr[n];
#else
uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr);
uint32_t const *data = reinterpret_cast<uint32_t const *>(frag_ptr + n);
uint32_t offset_in_bytes = offset * sizeof(AccessType);
asm volatile(
"{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n"
: : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1])
);
#endif
ptr[offset] = frag_ptr[n];
}
}

View File

@ -455,14 +455,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

View File

@ -445,14 +445,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

View File

@ -423,14 +423,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

View File

@ -437,14 +437,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

View File

@ -438,14 +438,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;

View File

@ -352,14 +352,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

View File

@ -325,14 +325,6 @@ public:
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;

View File

@ -103,8 +103,8 @@ template <
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for SM80 out-of-bound cp.async
bool UseZfill = false,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
///
typename Enable = void>
struct DefaultGemmWithKReduction {
@ -116,7 +116,7 @@ struct DefaultGemmWithKReduction {
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kReduceKForA, arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator, false, UseZfill>::ThreadblockMma;
Operator, false, SharedMemoryClear>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

View File

@ -34,6 +34,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/arch/arch.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -130,7 +130,6 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, 2, Operator, false, SharedMemoryClearOption::kNone> {
static_assert(platform::is_same<LayoutC, layout::RowMajor>::value
|| platform::is_same<LayoutC, layout::AffineRankN<2>>::value,
"simt epilogue must be row major");

View File

@ -141,8 +141,8 @@ struct DefaultMmaWithReductionCore {
using SmemLayoutB = typename Base::SmemLayoutB;
using WarpCount = typename Base::WarpCount;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp<

View File

@ -82,9 +82,10 @@ template <
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Use zfill or predicate for SM80 out-of-bound cp.async
bool UseZfill = false
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone
>
struct DefaultMmaWithReduction {
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
@ -122,7 +123,7 @@ struct DefaultMmaWithReduction {
typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor,
typename MmaCore::MmaPolicy, Stages, UseZfill>;
typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>;
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -303,10 +303,8 @@ public:
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
@ -447,10 +445,8 @@ public:
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
@ -558,10 +554,8 @@ public:
}
--gemm_k_iterations;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
// Do any conversions feeding the first stage at the end of the loop so

View File

@ -231,10 +231,8 @@ public:
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
if (gemm_k_iterations <= 1) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
@ -302,10 +300,8 @@ public:
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
if (gemm_k_iterations <= 2) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
warp_mma(accum, warp_frag_A[warp_mma_k % 2],

View File

@ -370,12 +370,10 @@ public:
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
if (gemm_k_iterations == 0) {
iterator_A_real.clear_mask();
iterator_A_imag.clear_mask();
iterator_B_real.clear_mask();
iterator_B_imag.clear_mask();
}
iterator_A_real.clear_mask(gemm_k_iterations == 0);
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
iterator_B_real.clear_mask(gemm_k_iterations == 0);
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
iterator_A_real.set_iteration_index(0);
iterator_A_imag.set_iteration_index(0);
@ -501,12 +499,10 @@ public:
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (gemm_k_iterations == 0) {
iterator_A_real.clear_mask();
iterator_A_imag.clear_mask();
iterator_B_real.clear_mask();
iterator_B_imag.clear_mask();
}
iterator_A_real.clear_mask(gemm_k_iterations == 0);
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
iterator_B_real.clear_mask(gemm_k_iterations == 0);
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
// Start issuing the first group of the next stage outside of the mainloop
copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
@ -611,12 +607,10 @@ public:
}
--gemm_k_iterations;
if (gemm_k_iterations == 0) {
iterator_A_real.clear_mask();
iterator_A_imag.clear_mask();
iterator_B_real.clear_mask();
iterator_B_imag.clear_mask();
}
iterator_A_real.clear_mask(gemm_k_iterations == 0);
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
iterator_B_real.clear_mask(gemm_k_iterations == 0);
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
}
warp_mma_planar_complex(

View File

@ -308,13 +308,11 @@ public:
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
if (gemm_k_iterations <= 1) {
iterator_A_real.clear_mask();
iterator_A_imag.clear_mask();
iterator_B_real.clear_mask();
iterator_B_imag.clear_mask();
}
iterator_A_real.clear_mask(gemm_k_iterations <= 1);
iterator_A_imag.clear_mask(gemm_k_iterations <= 1);
iterator_B_real.clear_mask(gemm_k_iterations <= 1);
iterator_B_imag.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
@ -392,12 +390,10 @@ public:
++iterator_B_imag;
// Avoid reading out of bounds if this was the last loop iteration
if (gemm_k_iterations <= 2) {
iterator_A_real.clear_mask();
iterator_A_imag.clear_mask();
iterator_B_real.clear_mask();
iterator_B_imag.clear_mask();
}
iterator_A_real.clear_mask(gemm_k_iterations <= 2);
iterator_A_imag.clear_mask(gemm_k_iterations <= 2);
iterator_B_real.clear_mask(gemm_k_iterations <= 2);
iterator_B_imag.clear_mask(gemm_k_iterations <= 2);
}
warp_mma_planar_complex(

View File

@ -196,10 +196,8 @@ public:
Operator warp_mma;
// Avoid reading out of bounds
if (gemm_k_iterations <= 1) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
//
// Mainloop
@ -247,10 +245,8 @@ public:
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
if (gemm_k_iterations <= 2) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
}

View File

@ -379,11 +379,9 @@ public:
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
iterator_E.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_E.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
@ -500,11 +498,9 @@ public:
++this->warp_tile_iterator_B_;
++this->warp_tile_iterator_E_;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
iterator_E.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_E.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
@ -637,11 +633,9 @@ public:
}
--gemm_k_iterations;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
iterator_E.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_E.clear_mask(gemm_k_iterations == 0);
}
// Do any conversions feeding the first stage at the end of the loop so

View File

@ -78,7 +78,7 @@ template <
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
bool UseZfill = false,
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class MmaWithReductionMultistage :
@ -234,7 +234,7 @@ public:
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (UseZfill) {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
@ -269,7 +269,7 @@ public:
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
if (UseZfill) {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
@ -302,16 +302,14 @@ public:
//
// Prologue
//
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
@ -403,10 +401,8 @@ public:
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
@ -515,10 +511,8 @@ public:
}
--gemm_k_iterations;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
}
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
// Do any conversions feeding the first stage at the end of the loop so
@ -532,7 +526,7 @@ public:
}
if (UseZfill) {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();

View File

@ -49,7 +49,6 @@ class MmaTensorOpFragmentIterator;
// Partial specialization for col-major accumulator tile
// And Element type is the same as Accumulator Element type
template <
/// Shape of warp tile to load (concept: MatrixShape)
@ -58,13 +57,15 @@ template <
typename AccumulatorShape_,
/// KBlocks columns to compute residual
int KBlocksColumn_,
/// Accumulator Element type
typename ElementAccumulator_,
/// Element type
typename Element_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_,
/// Output operation on fragment
typename OutputOp_>
class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Element_, Element_,
class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, ElementAccumulator_, Element_,
cutlass::layout::ColumnMajor,
InstructionShape_, OutputOp_> {
public:
@ -78,6 +79,9 @@ class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Ele
/// KBlocks columns to compute residual
static int const kKBlockColumn = KBlocksColumn_;
/// Accumulator Element type
using ElementAccumulator = ElementAccumulator_;
/// Element type
using Element = Element_;
@ -143,13 +147,14 @@ public:
using Fragment = Array<Element, Shape::kCount / kThreads>;
/// Accumulator Fragment object
using AccumulatorFragment = Array<Element, AccumulatorShape::kCount / kThreads>;
using AccumulatorFragment = Array<ElementAccumulator, AccumulatorShape::kCount / kThreads>;
private:
/// Internal access type
using AccessType = Array<Element, kElementsPerAccess>;
using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentAccessType = Array<Element, kElementsPerAccess>;
private:
//
@ -203,10 +208,10 @@ public:
if (output_op.is_source_needed()) //beta must be zero
assert(0);
AccessType src_fragment;
FragmentAccessType src_fragment;
src_fragment.clear();
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
int index = index_ * MmaIterations::kCount;

View File

@ -14030,15 +14030,15 @@ struct Matrix<Element_, 4, 4> {
/// Returns a perspective projection matrix typical of OpenGL applications
CUTLASS_HOST_DEVICE
static Matrix perspective(Element near, Element far, Element fovH, Element fovV) {
static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) {
Element aspect = fovH / fovV;
Element f = Element(cos(fovV)) / Element(fovH);
Element Q = near - far;
Element Q = near_plane - far_plane;
return Matrix(
f / aspect, 0, 0, 0,
0, f, 0, 0,
0, 0, (near + far) / Q, Element(2) * far * near / Q,
0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q,
0, 0, -1, 0
);
}

View File

@ -245,10 +245,10 @@ class PredicatedTileAccessIteratorPredicates {
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
void clear_mask(bool enable = true) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kPredicateWordCount; ++i) {
predicates_[i] = 0u;
predicates_[i] = enable ? 0u : predicates_[i];
}
}
@ -551,8 +551,8 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
the_predicates.clear_mask();
void clear_mask(bool enable = true) {
the_predicates.clear_mask(enable);
}
/// Clears the predicate set efficiently
@ -741,7 +741,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -922,7 +922,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -1224,7 +1224,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRankN<2>,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { the_predicates.clear_mask(); }
void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -1401,7 +1401,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2ColumnMa
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -1578,7 +1578,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2RowMajor
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -1764,7 +1764,7 @@ class PredicatedTileAccessIterator<Shape_, Element_,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -1948,7 +1948,7 @@ class PredicatedTileAccessIterator<Shape_, Element_,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE

View File

@ -403,10 +403,10 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::PitchLi
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
void clear_mask(bool enable = true) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kPredicateWordCount; ++i) {
predicates_[i] = 0u;
predicates_[i] = enable ? 0u : predicates_[i];
}
}
@ -617,7 +617,7 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::ColumnM
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -796,7 +796,7 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::RowMajo
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE

View File

@ -288,7 +288,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { address_iterator_.clear_mask(); }
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -530,8 +530,8 @@ public:
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
iterator_.clear_mask();
void clear_mask(bool enable = true) {
iterator_.clear_mask(enable);
}
/// Clears the predicate set efficiently
@ -738,8 +738,8 @@ public:
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
iterator_.clear_mask();
void clear_mask(bool enable = true) {
iterator_.clear_mask(enable);
}
/// Clears the predicate set efficiently
@ -946,7 +946,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::AffineRankN<2>, AdvanceRa
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { address_iterator_.clear_mask(); }
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -1184,8 +1184,8 @@ public:
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
iterator_.clear_mask();
void clear_mask(bool enable = true) {
iterator_.clear_mask(enable);
}
/// Clears the predicate set efficiently
@ -1388,8 +1388,8 @@ public:
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
iterator_.clear_mask();
void clear_mask(bool enable = true) {
iterator_.clear_mask(enable);
}
/// Clears the predicate set efficiently
@ -1600,7 +1600,7 @@ class PredicatedTileIterator<Shape_, Element_,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -1785,7 +1785,7 @@ class PredicatedTileIterator<Shape_, Element_,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { iterator_.clear_mask(); }
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE

View File

@ -293,7 +293,7 @@ class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear,
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() { address_iterator_.clear_mask(); }
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
@ -525,8 +525,8 @@ public:
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
iterator_.clear_mask();
void clear_mask(bool enable = true) {
iterator_.clear_mask(enable);
}
/// Clears the predicate set efficiently
@ -721,8 +721,8 @@ public:
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask() {
iterator_.clear_mask();
void clear_mask(bool enable = true) {
iterator_.clear_mask(enable);
}
/// Clears the predicate set efficiently