CUTLASS 2.6.1 - functional and performance enhancements to strided DGRAD, fixes, and tuning
* cutlass 2.6 update * remove debug prints * cutlass 2.6.1 (minor update) * Updated CHANGELOG. * Minor edit to readme to indicate patch version. * Minor edit to readme. Co-authored-by: Haicheng Wu <haichengw@nvidia.com>, Andrew Kerr <akerr@nvidia.com>
This commit is contained in:
@ -47,6 +47,7 @@
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
@ -485,6 +486,27 @@ int strided_dgrad_tile_m_per_filter(
|
||||
return tile_m_per_filter;
|
||||
}
|
||||
|
||||
// Computes starting Dx coord (h, w) for given starting filter postion
|
||||
CUTLASS_HOST_DEVICE
|
||||
void strided_dgrad_starting_coords(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int r, int s,
|
||||
int &start_h, int &start_w) {
|
||||
|
||||
// function locals for remainder by fast divmod
|
||||
int pad_h_rem_, pad_w_rem_;
|
||||
|
||||
// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
|
||||
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
|
||||
int r_ = std::abs(problem_size.stride_h - (pad_h_rem_ - r));
|
||||
stride_h_divmod.divmod(start_h, r_);
|
||||
|
||||
//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
|
||||
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
|
||||
int s_ = std::abs(problem_size.stride_w - (pad_w_rem_ - s));
|
||||
stride_w_divmod.divmod(start_w, s_);
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
@ -217,14 +217,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
cutlass::Kernel<ImplicitGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -199,7 +199,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
struct Params {
|
||||
ConvProblemSize problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
FastDivmod filter_s_divmod;
|
||||
FastDivmod stride_h_divmod;
|
||||
FastDivmod stride_w_divmod;
|
||||
int gemm_k_iterations;
|
||||
typename Mma::IteratorA::Params iterator_A;
|
||||
typename Mma::IteratorA::Element const *ptr_A;
|
||||
@ -227,7 +228,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
int *semaphore = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
filter_s_divmod(args.problem_size.stride_w),
|
||||
stride_h_divmod(args.problem_size.stride_h),
|
||||
stride_w_divmod(args.problem_size.stride_w),
|
||||
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
|
||||
ptr_A(args.ref_A.data()),
|
||||
iterator_B(args.problem_size, args.ref_B.layout()),
|
||||
@ -297,7 +299,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
// int start_s = filter_tile_m % (params.problem_size.stride_w);
|
||||
|
||||
int start_r, start_s;
|
||||
params.filter_s_divmod(start_r, start_s, filter_tile_m);
|
||||
params.stride_w_divmod(start_r, start_s, filter_tile_m);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
@ -320,6 +322,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.problem_size,
|
||||
params.ptr_A,
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.m() * Mma::Shape::kM,
|
||||
@ -386,6 +389,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.ptr_D,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset
|
||||
);
|
||||
@ -396,6 +400,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.ptr_C,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
@ -130,7 +130,6 @@ private:
|
||||
int offset_p_[ThreadMap::Iterations::kStrided];
|
||||
int offset_q_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -139,6 +138,7 @@ public:
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
|
||||
):
|
||||
@ -164,9 +164,12 @@ public:
|
||||
}
|
||||
|
||||
// Starting h, w positions for filter position in gemm_k=0
|
||||
int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h);
|
||||
int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w);
|
||||
|
||||
int start_h, start_w;
|
||||
strided_dgrad_starting_coords(
|
||||
problem_size_,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
filter_r, filter_s,
|
||||
start_h, start_w);
|
||||
|
||||
// Effective P and Q for filter position required for remapping NHW rows
|
||||
int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;
|
||||
|
||||
@ -200,7 +200,27 @@ private:
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
/// Constructor (output gradient (Dy) OperandA ctor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorStridedDgrad(
|
||||
Params const ¶ms,
|
||||
ConvProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
tile_access_iterator_(
|
||||
params,
|
||||
problem_size,
|
||||
ptr,
|
||||
thread_idx,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset) { }
|
||||
|
||||
/// Constructor (filter (w) OperandB ctor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorStridedDgrad(
|
||||
Params const ¶ms,
|
||||
@ -210,7 +230,12 @@ public:
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { }
|
||||
tile_access_iterator_(params,
|
||||
problem_size,
|
||||
ptr,
|
||||
thread_idx,
|
||||
start_r, start_s,
|
||||
threadblock_offset) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {
|
||||
|
||||
@ -31,6 +31,12 @@
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_NAMESPACE
|
||||
#define cutlass CUTLASS_NAMESPACE
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
||||
@ -174,12 +174,12 @@ public:
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
ComputeFragment converted_source = source_converter(source);
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
if (Scale == ScaleType::Nothing)
|
||||
return destination_converter(converted_accumulator);
|
||||
|
||||
ComputeFragment converted_source = source_converter(source);
|
||||
|
||||
// Perform binary operations
|
||||
ComputeFragment intermediate;
|
||||
|
||||
@ -309,9 +309,12 @@ struct DefaultEpilogueTensorOp {
|
||||
kElementsPerAccess
|
||||
>::Type;
|
||||
|
||||
static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
|
||||
|
||||
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
OutputTileThreadMap,
|
||||
ElementOutput
|
||||
ElementOutput,
|
||||
UseCUDAStore
|
||||
>;
|
||||
|
||||
using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
|
||||
|
||||
@ -62,7 +62,8 @@ namespace threadblock {
|
||||
///
|
||||
template <
|
||||
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
||||
typename Element_ ///< Element data type
|
||||
typename Element_, ///< Element data type
|
||||
bool UseCUDAStore = false
|
||||
>
|
||||
class PredicatedTileIterator {
|
||||
public:
|
||||
@ -341,10 +342,17 @@ public:
|
||||
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
|
||||
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
|
||||
guard);
|
||||
if (UseCUDAStore) {
|
||||
if (guard) {
|
||||
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
|
||||
}
|
||||
} else {
|
||||
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
|
||||
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
|
||||
guard);
|
||||
}
|
||||
}
|
||||
|
||||
if (row + 1 < ThreadMap::Iterations::kRow) {
|
||||
|
||||
@ -222,6 +222,7 @@ public:
|
||||
Element *pointer,
|
||||
TensorCoord extent,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
TensorCoord threadblock_offset = TensorCoord()
|
||||
):
|
||||
@ -238,9 +239,12 @@ public:
|
||||
s = (params_.problem_size.S - 1 - s);
|
||||
}
|
||||
|
||||
// check if start_h_ and start_w_ are always positive
|
||||
start_h_ = std::abs((params_.problem_size.pad_h - r) % params_.problem_size.stride_h);
|
||||
start_w_ = std::abs((params_.problem_size.pad_w - s) % params_.problem_size.stride_w);
|
||||
// compute starting coordinates in Dx start_h_ and start_w_
|
||||
strided_dgrad_starting_coords(
|
||||
params_.problem_size,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
r, s,
|
||||
start_h_, start_w_);
|
||||
|
||||
p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h;
|
||||
q_ = (params_.problem_size.W - start_w_ + params_.problem_size.stride_w - 1) / params_.problem_size.stride_w;
|
||||
|
||||
@ -256,20 +256,7 @@ public:
|
||||
|
||||
|
||||
int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess;
|
||||
#if 0
|
||||
// Using inline PTX to avoid generic memory
|
||||
AccessType *smem_ptr = pointers_[ptr_idx];
|
||||
smem_ptr[offset] = frag_ptr[n];
|
||||
#else
|
||||
uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr);
|
||||
uint32_t const *data = reinterpret_cast<uint32_t const *>(frag_ptr + n);
|
||||
uint32_t offset_in_bytes = offset * sizeof(AccessType);
|
||||
|
||||
asm volatile(
|
||||
"{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n"
|
||||
: : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1])
|
||||
);
|
||||
#endif
|
||||
ptr[offset] = frag_ptr[n];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -455,14 +455,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -445,14 +445,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -423,14 +423,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -437,14 +437,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -438,14 +438,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -352,14 +352,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
|
||||
|
||||
@ -325,14 +325,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -103,8 +103,8 @@ template <
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for SM80 out-of-bound cp.async
|
||||
bool UseZfill = false,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
///
|
||||
typename Enable = void>
|
||||
struct DefaultGemmWithKReduction {
|
||||
@ -116,7 +116,7 @@ struct DefaultGemmWithKReduction {
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kReduceKForA, arch::Sm80,
|
||||
ThreadblockShape, WarpShape, InstructionShape, Stages,
|
||||
Operator, false, UseZfill>::ThreadblockMma;
|
||||
Operator, false, SharedMemoryClear>::ThreadblockMma;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -130,7 +130,6 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape,
|
||||
InstructionShape, 2, Operator, false, SharedMemoryClearOption::kNone> {
|
||||
|
||||
|
||||
static_assert(platform::is_same<LayoutC, layout::RowMajor>::value
|
||||
|| platform::is_same<LayoutC, layout::AffineRankN<2>>::value,
|
||||
"simt epilogue must be row major");
|
||||
|
||||
@ -141,8 +141,8 @@ struct DefaultMmaWithReductionCore {
|
||||
using SmemLayoutB = typename Base::SmemLayoutB;
|
||||
using WarpCount = typename Base::WarpCount;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp<
|
||||
|
||||
@ -82,9 +82,10 @@ template <
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false,
|
||||
/// Use zfill or predicate for SM80 out-of-bound cp.async
|
||||
bool UseZfill = false
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone
|
||||
>
|
||||
struct DefaultMmaWithReduction {
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
@ -122,7 +123,7 @@ struct DefaultMmaWithReduction {
|
||||
typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
|
||||
MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor,
|
||||
typename MmaCore::MmaPolicy, Stages, UseZfill>;
|
||||
typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -303,10 +303,8 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
@ -447,10 +445,8 @@ public:
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@ -558,10 +554,8 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
|
||||
@ -231,10 +231,8 @@ public:
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations <= 1) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
@ -302,10 +300,8 @@ public:
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations <= 2) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
warp_mma(accum, warp_frag_A[warp_mma_k % 2],
|
||||
|
||||
@ -370,12 +370,10 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A_real.set_iteration_index(0);
|
||||
iterator_A_imag.set_iteration_index(0);
|
||||
@ -501,12 +499,10 @@ public:
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
|
||||
@ -611,12 +607,10 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
warp_mma_planar_complex(
|
||||
|
||||
@ -308,13 +308,11 @@ public:
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations <= 1) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
iterator_B_real.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
@ -392,12 +390,10 @@ public:
|
||||
++iterator_B_imag;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations <= 2) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
warp_mma_planar_complex(
|
||||
|
||||
@ -196,10 +196,8 @@ public:
|
||||
Operator warp_mma;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations <= 1) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
@ -247,10 +245,8 @@ public:
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations <= 2) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -379,11 +379,9 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_E.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
@ -500,11 +498,9 @@ public:
|
||||
++this->warp_tile_iterator_B_;
|
||||
++this->warp_tile_iterator_E_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_E.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@ -637,11 +633,9 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_E.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
|
||||
@ -78,7 +78,7 @@ template <
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
bool UseZfill = false,
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaWithReductionMultistage :
|
||||
@ -234,7 +234,7 @@ public:
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (UseZfill) {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
} else {
|
||||
@ -269,7 +269,7 @@ public:
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (UseZfill) {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
@ -302,16 +302,14 @@ public:
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
@ -403,10 +401,8 @@ public:
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@ -515,10 +511,8 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
@ -532,7 +526,7 @@ public:
|
||||
|
||||
}
|
||||
|
||||
if (UseZfill) {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
|
||||
@ -49,7 +49,6 @@ class MmaTensorOpFragmentIterator;
|
||||
|
||||
|
||||
// Partial specialization for col-major accumulator tile
|
||||
// And Element type is the same as Accumulator Element type
|
||||
|
||||
template <
|
||||
/// Shape of warp tile to load (concept: MatrixShape)
|
||||
@ -58,13 +57,15 @@ template <
|
||||
typename AccumulatorShape_,
|
||||
/// KBlocks columns to compute residual
|
||||
int KBlocksColumn_,
|
||||
/// Accumulator Element type
|
||||
typename ElementAccumulator_,
|
||||
/// Element type
|
||||
typename Element_,
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
typename InstructionShape_,
|
||||
/// Output operation on fragment
|
||||
typename OutputOp_>
|
||||
class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Element_, Element_,
|
||||
class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, ElementAccumulator_, Element_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
InstructionShape_, OutputOp_> {
|
||||
public:
|
||||
@ -78,6 +79,9 @@ class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Ele
|
||||
/// KBlocks columns to compute residual
|
||||
static int const kKBlockColumn = KBlocksColumn_;
|
||||
|
||||
/// Accumulator Element type
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
/// Element type
|
||||
using Element = Element_;
|
||||
|
||||
@ -143,13 +147,14 @@ public:
|
||||
using Fragment = Array<Element, Shape::kCount / kThreads>;
|
||||
|
||||
/// Accumulator Fragment object
|
||||
using AccumulatorFragment = Array<Element, AccumulatorShape::kCount / kThreads>;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, AccumulatorShape::kCount / kThreads>;
|
||||
|
||||
|
||||
private:
|
||||
|
||||
/// Internal access type
|
||||
using AccessType = Array<Element, kElementsPerAccess>;
|
||||
using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using FragmentAccessType = Array<Element, kElementsPerAccess>;
|
||||
|
||||
private:
|
||||
//
|
||||
@ -203,10 +208,10 @@ public:
|
||||
if (output_op.is_source_needed()) //beta must be zero
|
||||
assert(0);
|
||||
|
||||
AccessType src_fragment;
|
||||
FragmentAccessType src_fragment;
|
||||
src_fragment.clear();
|
||||
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
|
||||
|
||||
int index = index_ * MmaIterations::kCount;
|
||||
|
||||
|
||||
@ -14030,15 +14030,15 @@ struct Matrix<Element_, 4, 4> {
|
||||
|
||||
/// Returns a perspective projection matrix typical of OpenGL applications
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Matrix perspective(Element near, Element far, Element fovH, Element fovV) {
|
||||
static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) {
|
||||
Element aspect = fovH / fovV;
|
||||
Element f = Element(cos(fovV)) / Element(fovH);
|
||||
Element Q = near - far;
|
||||
Element Q = near_plane - far_plane;
|
||||
|
||||
return Matrix(
|
||||
f / aspect, 0, 0, 0,
|
||||
0, f, 0, 0,
|
||||
0, 0, (near + far) / Q, Element(2) * far * near / Q,
|
||||
0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q,
|
||||
0, 0, -1, 0
|
||||
);
|
||||
}
|
||||
|
||||
@ -245,10 +245,10 @@ class PredicatedTileAccessIteratorPredicates {
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool enable = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kPredicateWordCount; ++i) {
|
||||
predicates_[i] = 0u;
|
||||
predicates_[i] = enable ? 0u : predicates_[i];
|
||||
}
|
||||
|
||||
}
|
||||
@ -551,8 +551,8 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
the_predicates.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
the_predicates.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -741,7 +741,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -922,7 +922,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1224,7 +1224,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRankN<2>,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { the_predicates.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1401,7 +1401,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2ColumnMa
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1578,7 +1578,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2RowMajor
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1764,7 +1764,7 @@ class PredicatedTileAccessIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1948,7 +1948,7 @@ class PredicatedTileAccessIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -403,10 +403,10 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::PitchLi
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool enable = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kPredicateWordCount; ++i) {
|
||||
predicates_[i] = 0u;
|
||||
predicates_[i] = enable ? 0u : predicates_[i];
|
||||
}
|
||||
|
||||
}
|
||||
@ -617,7 +617,7 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::ColumnM
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -796,7 +796,7 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::RowMajo
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -288,7 +288,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { address_iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -530,8 +530,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -738,8 +738,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -946,7 +946,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::AffineRankN<2>, AdvanceRa
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { address_iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1184,8 +1184,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -1388,8 +1388,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -1600,7 +1600,7 @@ class PredicatedTileIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1785,7 +1785,7 @@ class PredicatedTileIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -293,7 +293,7 @@ class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { address_iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -525,8 +525,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -721,8 +721,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
|
||||
Reference in New Issue
Block a user