From 7ec3a87f22344bf11f8b411c72cd8759583da374 Mon Sep 17 00:00:00 2001 From: "mengchi.hmc" Date: Wed, 21 Apr 2021 14:28:58 +0800 Subject: [PATCH 1/4] support unalignment input for conv2d fprop stage=2 Fix for issue #242 --- .../conv/kernel/default_conv2d_fprop.h | 231 ++++++++++++++++++ ...rad_filter_tile_access_iterator_analytic.h | 2 + ...ad_filter_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 78 ++++-- ...rop_filter_tile_access_iterator_analytic.h | 2 + ...op_filter_tile_access_iterator_optimized.h | 78 ++++-- .../conv/threadblock/conv2d_tile_iterator.h | 21 +- ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + ...rad_filter_tile_access_iterator_analytic.h | 2 + ...ad_filter_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 2 + ...rop_filter_tile_access_iterator_analytic.h | 2 + ...op_filter_tile_access_iterator_optimized.h | 2 + ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + .../threadblock/implicit_gemm_multistage.h | 64 +++-- 27 files changed, 444 insertions(+), 72 deletions(-) diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index d22fb7f0..030e5ca5 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -66,6 +66,10 @@ template < 
int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// whether Matrix A is 128b aligned + bool AlignedA = true, + /// whether Matrix B is 128b aligned + bool AlignedB = true, conv::StrideSupport StrideSupport = StrideSupport::kStrided > struct DefaultConv2dFprop; @@ -515,6 +519,119 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and +/// multistage pipeline with unaligned data +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + bool AlignedA, + bool AlignedB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + AlignedA, + AlignedB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AlignedA + >; + + using SmemIteratorA = typename 
MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AlignedB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and /// multistage pipeline. 
template < @@ -729,6 +846,120 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm +/// and 2 stage pipeline with disalignment data +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + bool AlignedA, + bool AlignedB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + AlignedA, + AlignedB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AlignedA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + 
cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AlignedB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm /// and 2 stage pipeline. 
template < diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h index 8afb4968..026b2b2f 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -90,6 +90,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h index 937216d5..86e3140e 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static StrideSupport const kStrideSupport = StrideSupport_; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Parameters structure diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index e33e4ccb..edc42df1 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -111,6 +111,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git 
a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h index 078c9e7f..06c3ecf4 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -100,6 +100,8 @@ public: using Params = Conv2dDgradOutputGradientIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h index 51a51504..4943b9b7 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -95,6 +95,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index 573255da..bb720cf7 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -60,7 +60,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + bool Aligned = true > class Conv2dFpropActivationTileAccessIteratorOptimized { public: @@ -74,7 +75,8 @@ public: using Layout = Layout_; using TensorCoord = 
typename Layout::TensorCoord; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1; + using AccessType = AlignedArray; using TensorRef = cutlass::TensorRef; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; @@ -97,12 +99,15 @@ public: using Params = Conv2dFpropActivationIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dFpropActivationIteratorOptimizedParams const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; // One pointer per access char const *pointer_[ThreadMap::Iterations::kStrided]; @@ -112,7 +117,7 @@ private: int filter_s_; int filter_c_; - Index masks_[ThreadMap::Iterations::kStrided][2]; + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; public: @@ -180,7 +185,11 @@ public: int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H); - masks_[s_idx][0] |= (pred << r); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } } } @@ -197,13 +206,18 @@ public: int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; bool pred = (w >= 0 && w < problem_size_.W); - masks_[s_idx][1] |= (pred << s); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } } } - if (filter_c_ >= problem_size.C) { - clear_mask(); - } + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + } set_iteration_index(0); } @@ -250,7 +264,7 @@ private: /// Clears 
the predicates CUTLASS_HOST_DEVICE - void clear_mask_(bool clear) { + void clear_mask_(bool clear, int index) { CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { @@ -268,10 +282,10 @@ private: " mov.u32 %0, m;\n" "}\n" : - "=r"(masks_[s][0]) + "=r"(masks_[s][index][0]) : "r"((int)clear), - "r"(masks_[s][0]) + "r"(masks_[s][index][0]) ); asm volatile( "{\n" @@ -283,15 +297,15 @@ private: " mov.u32 %0, m;\n" "}\n" : - "=r"(masks_[s][1]) + "=r"(masks_[s][index][1]) : "r"((int)clear), - "r"(masks_[s][1]) + "r"(masks_[s][index][1]) ); #else if (clear) { - masks_[s][0] = 0; - masks_[s][1] = 0; + masks_[s][index][0] = 0; + masks_[s][index][1] = 0; } #endif } @@ -302,8 +316,11 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of element @@ -338,7 +355,10 @@ public: filter_c_ += params_.filter_c_delta; } - clear_mask_(filter_c_ >= problem_size_.C); + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + } } /// Clears the predicates @@ -346,8 +366,11 @@ public: void clear_mask() { CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][0] = Mask(0); - masks_[s][1] = Mask(0); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = Mask(0); + masks_[s][v][1] = Mask(0); + } } } @@ -355,21 +378,28 @@ public: bool valid() { return - (masks_[iteration_strided_][0] & (Index(1) << filter_r_)) && - 
(masks_[iteration_strided_][1] & (Index(1) << filter_s_)); + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); } /// Returns a pointer to the vector starting at the current coordinate CUTLASS_HOST_DEVICE AccessType const *get() const { - return reinterpret_cast(pointer_[iteration_strided_]); + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dFpropActivationTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -390,7 +420,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (Aligned && problem_size.C % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index b0a89ada..48a51935 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -94,6 +94,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 2f12e41f..9781e42f 100644 --- 
a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -60,7 +60,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + bool Aligned = true > class Conv2dFpropFilterTileAccessIteratorOptimized{ public: @@ -73,7 +74,8 @@ public: using Element = Element_; using Layout = Layout_; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1; + using AccessType = AlignedArray; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -89,6 +91,8 @@ public: static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // @@ -127,9 +131,10 @@ private: Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; - uint32_t predicates_; + uint32_t predicates_[kAccessesPerVector]; int filter_rs_; int filter_c_; @@ -154,7 +159,7 @@ public: params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), - predicates_(0), + predicates_{0}, filter_rs_(0), filter_c_(0) { @@ -166,11 +171,14 @@ public: CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 
1u : 0); - predicates_ |= (pred << s); + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + predicates_[v_idx] |= (pred << s); + } } - if (filter_c_ >= problem_size.C) { - predicates_ = 0u; + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); } pointer_ += ( @@ -180,11 +188,44 @@ public: set_iteration_index(0); } + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask_(bool clear, int index) { + // We are using inline PTX assembly here to avoid an CUDA C++ compilation + // artifact in which control flow instructions are generated. Instead, our + // intent is to predicate the mov instructions. + #if defined(__CUDA_ARCH__) + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(predicates_[index]) + : + "r"((int)clear), + "r"(predicates_[index]) + ); + #else + if (clear) { + predicates_[index] = 0; + predicates_[index] = 0; + } + #endif + } + /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -206,9 +247,9 @@ public: next = params_.inc_next_c; filter_c_ += params_.filter_c_delta; } - - if (filter_c_ >= problem_size_.C) { - predicates_ = 0; + + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); } pointer_ += next; @@ -217,18 +258,25 @@ public: /// Returns true if the current 
coordinate is within the filter tensor W CUTLASS_HOST_DEVICE bool valid() { - return (predicates_ & (1u << iteration_strided_)); + return (predicates_[iteration_vector_] & (1u << iteration_strided_)); } /// Returns a pointer to the vector starting at the current coordinate CUTLASS_HOST_DEVICE AccessType const *get() const { - return reinterpret_cast(pointer_); + return reinterpret_cast(pointer_) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dFpropFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -253,7 +301,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (Aligned && problem_size.C % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h index 61f02d19..68fec78a 100644 --- a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -131,16 +131,19 @@ public: CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[c + s * ThreadMap::Iterations::kContiguous], - tile_access_iterator_.get() + pointer_offset, - tile_access_iterator_.valid() - ); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) { + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v], + tile_access_iterator_.get() + 
pointer_offset, + tile_access_iterator_.valid() + ); - ++tile_access_iterator_; + ++tile_access_iterator_; + } } } } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h index 1e3a5837..cb79844d 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -89,6 +89,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h index 7762d619..aae011b0 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -88,6 +88,8 @@ public: using Params = Conv2dWgradActivationIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h index 53fc9205..d9e12f87 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -89,6 +89,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params 
const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h index f138ef59..f4d7c7d4 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -88,6 +88,8 @@ public: using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h index 01437547..fcbba130 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or larger."); + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Parameters structure diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h index ee532ff6..8683d1d5 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static StrideSupport const kStrideSupport = StrideSupport_; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / 
AccessType::kElements; // // Parameters structure diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h index 1d70ab3d..92782550 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -96,6 +96,8 @@ public: static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or greater."); + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Simpligying assertions diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h index 2a62c292..53d1beec 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -101,6 +101,8 @@ public: using Params = Conv3dDgradOutputGradientIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h index 7cadf860..1c148d5b 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -94,6 +94,8 @@ public: using Params = Conv3dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git 
a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h index 9246c592..74c559c1 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -97,6 +97,8 @@ public: using Params = Conv3dFpropActivationIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index a7f54368..272fc246 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -93,6 +93,8 @@ public: using Params = Conv3dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index 5d814890..0fd161c5 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -89,6 +89,8 @@ public: static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h 
b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h index 396d856a..0e7ab37e 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 2835480d..0052fd67 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h index b8af8efa..73e96d4a 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git 
a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index d3b356e0..0bf96aff 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index aefdcd6d..cbc35f32 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -195,7 +195,8 @@ public: IteratorA &iterator_A, IteratorB &iterator_B, int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A); + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); this->smem_iterator_A_.set_iteration_index(group_start_A); // Async Copy for operand A @@ -208,18 +209,22 @@ public: this->smem_iterator_A_.get()); int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + ++iterator_A; + } ++this->smem_iterator_A_; } } - iterator_B.set_iteration_index(group_start_B); + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); 
this->smem_iterator_B_.set_iteration_index(group_start_B); @@ -232,12 +237,15 @@ public: this->smem_iterator_B_.get()); int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + ++iterator_B; + } ++this->smem_iterator_B_; } } @@ -279,14 +287,18 @@ public: reinterpret_cast( this->smem_iterator_A_.get()); - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_A.get(), iterator_A.valid()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); - ++iterator_A; + ++iterator_A; + } ++this->smem_iterator_A_; } @@ -300,14 +312,18 @@ public: reinterpret_cast( this->smem_iterator_B_.get()); - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_B.get(), iterator_B.valid()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - ++iterator_B; + ++iterator_B; + } ++this->smem_iterator_B_; } From bb35a3ba6f674a24686f25579f914b7b6692c5d7 Mon Sep 17 00:00:00 2001 From: "mengchi.hmc" Date: Thu, 22 Apr 2021 15:20:57 +0800 Subject: 
[PATCH 2/4] support setting load granularity for conv2d fprop --- .../conv/kernel/default_conv2d_fprop.h | 333 ++++-------------- ...ctivation_tile_access_iterator_optimized.h | 9 +- ...op_filter_tile_access_iterator_optimized.h | 9 +- 3 files changed, 86 insertions(+), 265 deletions(-) diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index 030e5ca5..88096b8e 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -66,10 +66,10 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - /// whether Matrix A is 128b aligned - bool AlignedA = true, - /// whether Matrix B is 128b aligned - bool AlignedB = true, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value, conv::StrideSupport StrideSupport = StrideSupport::kStrided > struct DefaultConv2dFprop; @@ -94,7 +94,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -113,7 +115,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -197,7 +201,9 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, - int InterleavedK + int InterleavedK, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -216,7 +222,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + AlignmentA, + AlignmentB > { // Define the 
core components from GEMM @@ -312,7 +320,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -331,7 +341,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -417,7 +429,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, - int InterleavedK + int InterleavedK, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -436,7 +450,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -520,7 +536,7 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and -/// multistage pipeline with unaligned data +/// multistage pipeline. 
template < typename ElementA, typename LayoutA, @@ -537,8 +553,8 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, - bool AlignedA, - bool AlignedB + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -558,8 +574,8 @@ struct DefaultConv2dFprop < Stages, MathOperatorTag, IteratorAlgorithm::kOptimized, - AlignedA, - AlignedB + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -577,7 +593,7 @@ struct DefaultConv2dFprop < ElementA, LayoutA, ThreadMapA, - AlignedA + AlignmentA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; @@ -590,7 +606,7 @@ struct DefaultConv2dFprop < ElementB, LayoutB, ThreadMapB, - AlignedB + AlignmentB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -607,114 +623,7 @@ struct DefaultConv2dFprop < arch::CacheOperation::Always, IteratorB, SmemIteratorB, - arch::CacheOperation::Global, - MmaPolicy, - Stages - >; - - // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< - ThreadblockShape, - WarpMmaTensorOp, - 1, - EpilogueOutputOp, - EpilogueOutputOp::kCount - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and -/// multistage pipeline. 
-template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - int Stages, - typename MathOperatorTag -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, MathOperatorTag - >; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmMultistage< - ThreadblockShape, - IteratorA, - SmemIteratorA, arch::CacheOperation::Always, - IteratorB, - SmemIteratorB, - 
arch::CacheOperation::Global, MmaPolicy, Stages >; @@ -755,7 +664,9 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, - int InterleavedK + int InterleavedK, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -774,7 +685,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -844,10 +757,8 @@ struct DefaultConv2dFprop < >; }; -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm -/// and 2 stage pipeline with disalignment data +/// and 2 stage pipeline. template < typename ElementA, typename LayoutA, @@ -863,8 +774,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, - bool AlignedA, - bool AlignedB + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -884,8 +795,8 @@ struct DefaultConv2dFprop < 2, MathOperatorTag, IteratorAlgorithm::kOptimized, - AlignedA, - AlignedB + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -903,7 +814,7 @@ struct DefaultConv2dFprop < ElementA, LayoutA, ThreadMapA, - AlignedA + AlignmentA > >; @@ -918,115 +829,7 @@ struct DefaultConv2dFprop < ElementB, LayoutB, ThreadMapB, - AlignedB - > - >; - - using SmemIteratorB = typename MmaCore::SmemIteratorB; - - // Warp-level GEMM components - using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; - using MmaPolicy = typename MmaCore::MmaPolicy; - - // Define the Mma - using Mma = threadblock::ImplicitGemmPipelined< - ThreadblockShape, - IteratorA, - SmemIteratorA, - IteratorB, - SmemIteratorB, - ElementC, - LayoutC, - MmaPolicy - >; - - // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< - ArchTag, - ThreadblockShape, - 
WarpMmaTensorOp, - 1, - EpilogueOutputOp - >::Epilogue; - - // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< - Mma, - Epilogue, - ThreadblockSwizzle, - conv::Operator::kFprop - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm -/// and 2 stage pipeline. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementAccumulator, - typename ArchTag, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename EpilogueOutputOp, - typename ThreadblockSwizzle, - typename MathOperatorTag -> -struct DefaultConv2dFprop < - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - arch::OpClassTensorOp, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - 2, - MathOperatorTag, - IteratorAlgorithm::kOptimized -> { - - // Define the core components from GEMM - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - 2, MathOperatorTag>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = - cutlass::conv::threadblock::TileIterator< - cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementA, - LayoutA, - ThreadMapA - > - >; - - using SmemIteratorA = typename MmaCore::SmemIteratorA; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using IteratorB = - cutlass::conv::threadblock::TileIterator< - 
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< - cutlass::MatrixShape, - ElementB, - LayoutB, - ThreadMapB + AlignmentB > >; @@ -1083,7 +886,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, - int InterleavedK + int InterleavedK, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -1102,7 +907,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -1194,7 +1001,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -1213,7 +1022,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -1299,7 +1110,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -1318,7 +1131,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -1404,7 +1219,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -1423,7 +1240,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ 
-1510,7 +1329,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dFprop < ElementA, @@ -1529,7 +1350,9 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + AlignmentA, + AlignmentB > { // Define the core components from GEMM diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index bb720cf7..9a272b05 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -61,7 +61,7 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - bool Aligned = true + int AccessSize = ThreadMap_::kElementsPerAccess > class Conv2dFpropActivationTileAccessIteratorOptimized { public: @@ -75,7 +75,6 @@ public: using Layout = Layout_; using TensorCoord = typename Layout::TensorCoord; using ThreadMap = ThreadMap_; - static int const AccessSize = Aligned ? 
ThreadMap::kElementsPerAccess : 1; using AccessType = AlignedArray; using TensorRef = cutlass::TensorRef; using Index = typename Layout::Index; @@ -216,7 +215,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); } set_iteration_index(0); @@ -357,7 +356,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); } } @@ -420,7 +419,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (Aligned && problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessSize) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 9781e42f..3229d6c5 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -61,7 +61,7 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - bool Aligned = true + int AccessSize = ThreadMap_::kElementsPerAccess > class Conv2dFpropFilterTileAccessIteratorOptimized{ public: @@ -74,7 +74,6 @@ public: using Element = Element_; using Layout = Layout_; using ThreadMap = ThreadMap_; - static int const AccessSize = Aligned ? 
ThreadMap::kElementsPerAccess : 1; using AccessType = AlignedArray; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -178,7 +177,7 @@ public: } for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); } pointer_ += ( @@ -249,7 +248,7 @@ public: } for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); } pointer_ += next; @@ -301,7 +300,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (Aligned && problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessSize) { return Status::kErrorInvalidProblem; } From f4b0a336339de6acd56134ddc7bc637bf6e772ac Mon Sep 17 00:00:00 2001 From: "mengchi.hmc" Date: Fri, 23 Apr 2021 14:33:46 +0800 Subject: [PATCH 3/4] add unit test for non int4 load --- .../conv/kernel/default_conv2d_fprop.h | 7 +- ...t_gradient_tile_access_iterator_analytic.h | 2 + ...op_filter_tile_access_iterator_optimized.h | 3 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu | 84 ++++++++++++ ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu | 129 ++++++++++++++++++ ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu | 88 ++++++++++++ 6 files changed, 311 insertions(+), 2 deletions(-) diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index 88096b8e..0ddbe6b3 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -615,6 +615,11 @@ struct DefaultConv2dFprop < using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; using MmaPolicy = typename MmaCore::MmaPolicy; + static cutlass::arch::CacheOperation::Kind const CacheOpB = 
+ ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + // Define the Mma using Mma = threadblock::ImplicitGemmMultistage< ThreadblockShape, @@ -623,7 +628,7 @@ struct DefaultConv2dFprop < arch::CacheOperation::Always, IteratorB, SmemIteratorB, - arch::CacheOperation::Always, + CacheOpB, MmaPolicy, Stages >; diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index edc42df1..e8fec8c1 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -341,6 +341,8 @@ public: // Parameters structure // + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + struct Params { Layout layout; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 3229d6c5..8674a530 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -176,6 +176,7 @@ public: } } + CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); } @@ -212,7 +213,6 @@ public: #else if (clear) { predicates_[index] = 0; - predicates_[index] = 0; } #endif } @@ -247,6 +247,7 @@ public: filter_c_ += params_.filter_c_delta; } + CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); } diff --git 
a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu index 3366f1b5..5b13841e 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu @@ -117,5 +117,89 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten EXPECT_TRUE(test::conv::device::TestAllConv2d()); } +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, + 128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + 2, + 2 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + 
EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, + 128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + 4, + 4 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} + //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu index 7b74e128..1dacc2e9 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu +++ 
b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu @@ -117,5 +117,134 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten EXPECT_TRUE(test::conv::device::TestAllConv2d()); } +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align1, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + 1, + 1 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align2, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = 
cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + 2, + 2 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + 
ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + 4, + 4 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} + //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu index 4c7b3d77..cad20695 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -77,5 +77,93 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_te EXPECT_TRUE(test::conv::device::TestAllConv2d()); } +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1, + 128x128_32x3_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::tfloat32_t; + using ElementB = cutlass::tfloat32_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 
16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + 1, + 1 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2, + 128x128_32x3_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::tfloat32_t; + using ElementB = cutlass::tfloat32_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + 2, + 2 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); 
+} + //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED From 59e2aa505ab70d48b4acb1ea3231d2c013235215 Mon Sep 17 00:00:00 2001 From: Haicheng Wu Date: Wed, 8 Sep 2021 13:14:08 +0000 Subject: [PATCH 4/4] refine the implementation --- .../conv/device/implicit_gemm_convolution.h | 12 ++ .../conv/kernel/default_conv2d_dgrad.h | 180 +++++++++++++----- .../conv/kernel/default_conv2d_fprop.h | 83 ++++++-- .../default_conv2d_fprop_with_broadcast.h | 10 +- .../default_conv2d_fprop_with_reduction.h | 10 +- .../conv/kernel/default_conv2d_wgrad.h | 125 +++++++++--- .../conv/kernel/default_conv3d_dgrad.h | 1 - .../conv/kernel/default_conv3d_wgrad.h | 1 - ...rad_filter_tile_access_iterator_analytic.h | 72 ++++--- ...ad_filter_tile_access_iterator_optimized.h | 59 ++++-- ...t_gradient_tile_access_iterator_analytic.h | 76 +++++--- ..._gradient_tile_access_iterator_optimized.h | 132 ++++++------- ...activation_tile_access_iterator_analytic.h | 31 ++- ...ctivation_tile_access_iterator_optimized.h | 97 +++------- ...rop_filter_tile_access_iterator_analytic.h | 33 ++-- ...op_filter_tile_access_iterator_optimized.h | 57 ++---- .../conv/threadblock/conv2d_tile_iterator.h | 11 +- ...activation_tile_access_iterator_analytic.h | 53 ++++-- ...ctivation_tile_access_iterator_optimized.h | 68 +++++-- ...t_gradient_tile_access_iterator_analytic.h | 35 ++-- ..._gradient_tile_access_iterator_optimized.h | 62 +++--- ...rad_filter_tile_access_iterator_analytic.h | 5 +- ...ad_filter_tile_access_iterator_optimized.h | 8 +- ...t_gradient_tile_access_iterator_analytic.h | 7 +- ..._gradient_tile_access_iterator_optimized.h | 5 +- ...activation_tile_access_iterator_analytic.h | 3 +- ...ctivation_tile_access_iterator_optimized.h | 4 +- ...rop_filter_tile_access_iterator_analytic.h | 3 +- ...op_filter_tile_access_iterator_optimized.h | 5 +- ...activation_tile_access_iterator_analytic.h | 4 +- 
...ctivation_tile_access_iterator_optimized.h | 4 +- ...t_gradient_tile_access_iterator_analytic.h | 4 +- ..._gradient_tile_access_iterator_optimized.h | 4 +- .../threadblock/implicit_gemm_multistage.h | 18 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu | 113 +++++++++++ ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu | 149 +++++++++++---- ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu | 105 ++++++++-- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu | 131 ++++++++----- ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu | 63 ++---- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu | 55 ++++++ test/unit/conv/device/conv2d_testbed.h | 1 + ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu | 111 +++++++++++ ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu | 113 ++++++++++- tools/library/scripts/conv2d_operation.py | 14 +- tools/library/scripts/generator.py | 162 ++++++++-------- 45 files changed, 1593 insertions(+), 706 deletions(-) diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index eed01b5d..da4c8029 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -105,6 +105,18 @@ public: return status; } + static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.K % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } + // check for unsupported problem sizes for strided dgrad implementation if (kConvolutionalOperator == conv::Operator::kDgrad && kStrideSupport == conv::StrideSupport::kStrided) { diff --git 
a/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/include/cutlass/conv/kernel/default_conv2d_dgrad.h index 9adda4dc..7c3e29e2 100644 --- a/include/cutlass/conv/kernel/default_conv2d_dgrad.h +++ b/include/cutlass/conv/kernel/default_conv2d_dgrad.h @@ -66,7 +66,11 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value > struct DefaultConv2dDgrad; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -90,7 +94,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -110,7 +116,9 @@ struct DefaultConv2dDgrad < Stages, MathOperatorTag, IteratorAlgorithm::kAnalytic, - StrideSupport::kStrided + StrideSupport::kStrided, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -121,24 +129,28 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kStrided + StrideSupport::kStrided, + AccessTypeA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = 
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, ThreadMapB, - StrideSupport::kStrided + StrideSupport::kStrided, + AccessTypeB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -147,6 +159,11 @@ struct DefaultConv2dDgrad < using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; using MmaPolicy = typename MmaCore::MmaPolicy; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + // Define the Mma using Mma = threadblock::ImplicitGemmMultistage< ThreadblockShape, @@ -155,7 +172,7 @@ struct DefaultConv2dDgrad < arch::CacheOperation::Always, IteratorB, SmemIteratorB, - arch::CacheOperation::Global, + CacheOpB, MmaPolicy, Stages >; @@ -196,7 +213,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -216,7 +235,9 @@ struct DefaultConv2dDgrad < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, - StrideSupport::kStrided + StrideSupport::kStrided, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -227,13 +248,15 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::TileIteratorStridedDgrad< cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kStrided + StrideSupport::kStrided, + AccessTypeA > >; @@ -241,13 +264,15 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = 
cutlass::conv::threadblock::TileIteratorStridedDgrad< cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, ThreadMapB, - StrideSupport::kStrided + StrideSupport::kStrided, + AccessTypeB > >; @@ -308,7 +333,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -328,7 +355,9 @@ struct DefaultConv2dDgrad < Stages, MathOperatorTag, IteratorAlgorithm::kAnalytic, - StrideSupport::kUnity + StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -339,24 +368,28 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kUnity + StrideSupport::kUnity, + AccessTypeA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, ThreadMapB, - StrideSupport::kUnity + StrideSupport::kUnity, + AccessTypeB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -365,6 +398,11 @@ struct DefaultConv2dDgrad < using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; using MmaPolicy = typename MmaCore::MmaPolicy; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + // Define the Mma using Mma = threadblock::ImplicitGemmMultistage< ThreadblockShape, @@ -373,7 +411,7 @@ struct DefaultConv2dDgrad < arch::CacheOperation::Always, IteratorB, SmemIteratorB, - arch::CacheOperation::Global, + CacheOpB, MmaPolicy, Stages >; @@ -414,7 +452,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -434,7 +474,9 @@ struct DefaultConv2dDgrad < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, - StrideSupport::kUnity + StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -445,13 +487,15 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kUnity + StrideSupport::kUnity, + AccessTypeA > >; @@ -459,13 +503,15 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, ThreadMapB, - StrideSupport::kUnity + StrideSupport::kUnity, + AccessTypeB > >; @@ -526,7 +572,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -546,7 +594,9 @@ struct DefaultConv2dDgrad < Stages, 
MathOperatorTag, IteratorAlgorithm::kOptimized, - StrideSupport::kUnity + StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -557,23 +607,28 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kUnity + StrideSupport::kUnity, + AccessTypeA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -582,6 +637,11 @@ struct DefaultConv2dDgrad < using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; using MmaPolicy = typename MmaCore::MmaPolicy; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + // Define the Mma using Mma = threadblock::ImplicitGemmMultistage< ThreadblockShape, @@ -590,7 +650,7 @@ struct DefaultConv2dDgrad < arch::CacheOperation::Always, IteratorB, SmemIteratorB, - arch::CacheOperation::Global, + CacheOpB, MmaPolicy, Stages >; @@ -631,7 +691,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -651,7 +713,9 @@ struct DefaultConv2dDgrad < 2, MathOperatorTag, IteratorAlgorithm::kOptimized, - StrideSupport::kUnity + StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -662,13 +726,15 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kUnity + StrideSupport::kUnity, + AccessTypeA > >; @@ -676,12 +742,15 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB > >; @@ -744,7 +813,10 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag> + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> struct DefaultConv2dDgrad < ElementA, LayoutA, @@ -763,7 +835,9 @@ struct DefaultConv2dDgrad < Stages, 
MathOperatorTag, IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kUnity + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -848,7 +922,10 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag> + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> struct DefaultConv2dDgrad < ElementA, LayoutA, @@ -867,7 +944,9 @@ struct DefaultConv2dDgrad < Stages, MathOperatorTag, IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -955,7 +1034,9 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -975,7 +1056,9 @@ struct DefaultConv2dDgrad < Stages, MathOperatorTag, IteratorAlgorithm::kOptimized, - StrideSupport::kUnity + StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -1040,10 +1123,8 @@ struct DefaultConv2dDgrad < ThreadblockSwizzle, conv::Operator::kDgrad >; - }; - ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1063,7 +1144,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -1083,7 +1166,9 @@ struct DefaultConv2dDgrad < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kUnity + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -1169,7 +1254,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename 
ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -1189,7 +1276,9 @@ struct DefaultConv2dDgrad < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, - conv::StrideSupport::kStrided + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -1257,7 +1346,6 @@ struct DefaultConv2dDgrad < ThreadblockSwizzle, conv::Operator::kDgrad >; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1278,7 +1366,9 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + int AlignmentA, + int AlignmentB > struct DefaultConv2dDgrad < ElementA, @@ -1298,7 +1388,9 @@ struct DefaultConv2dDgrad < 2, MathOperatorTag, IteratorAlgorithm::kOptimized, - StrideSupport::kUnity + StrideSupport::kUnity, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -1368,10 +1460,10 @@ struct DefaultConv2dDgrad < >; }; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace kernel } // namespace conv } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index 2798a419..205da469 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -66,11 +66,11 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of 
elements - int AlignmentB = 128 / cutlass::sizeof_bits::value, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + int AlignmentB = 128 / cutlass::sizeof_bits::value > struct DefaultConv2dFprop; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -95,6 +95,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -116,6 +117,7 @@ struct DefaultConv2dFprop < Stages, MathOperatorTag, IteratorAlgorithm::kAnalytic, + StrideSupport, AlignmentA, AlignmentB > { @@ -128,22 +130,26 @@ struct DefaultConv2dFprop < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, LayoutA, - ThreadMapA + ThreadMapA, + AccessTypeA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, LayoutB, - ThreadMapB + ThreadMapB, + AccessTypeB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -152,6 +158,11 @@ struct DefaultConv2dFprop < using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; using MmaPolicy = typename MmaCore::MmaPolicy; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + // Define the Mma using Mma = threadblock::ImplicitGemmMultistage< ThreadblockShape, @@ -160,7 +171,7 @@ struct DefaultConv2dFprop < arch::CacheOperation::Always, IteratorB, SmemIteratorB, - arch::CacheOperation::Global, + CacheOpB, MmaPolicy, Stages >; @@ -203,9 +214,10 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, - int InterleavedK, + conv::StrideSupport StrideSupport, int AlignmentA, - int AlignmentB + int AlignmentB, + int InterleavedK > struct DefaultConv2dFprop < ElementA, @@ -225,6 +237,7 @@ struct DefaultConv2dFprop < Stages, MathOperatorTag, IteratorAlgorithm::kAnalytic, + StrideSupport, AlignmentA, AlignmentB > { @@ -325,6 +338,7 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -346,6 +360,7 @@ struct DefaultConv2dFprop < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, + StrideSupport, AlignmentA, AlignmentB > { @@ -358,12 +373,14 @@ struct DefaultConv2dFprop < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, LayoutA, - ThreadMapA + ThreadMapA, + AccessTypeA > >; @@ -371,12 +388,14 @@ struct DefaultConv2dFprop < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, LayoutB, - ThreadMapB + ThreadMapB, + AccessTypeB > >; @@ -435,9 +454,10 @@ template < typename EpilogueOutputOp, typename 
ThreadblockSwizzle, typename MathOperatorTag, - int InterleavedK, + conv::StrideSupport StrideSupport, int AlignmentA, - int AlignmentB + int AlignmentB, + int InterleavedK > struct DefaultConv2dFprop < ElementA, @@ -457,6 +477,7 @@ struct DefaultConv2dFprop < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, + StrideSupport, AlignmentA, AlignmentB > { @@ -561,6 +582,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -582,6 +604,7 @@ struct DefaultConv2dFprop < Stages, MathOperatorTag, IteratorAlgorithm::kOptimized, + StrideSupport, AlignmentA, AlignmentB > { @@ -595,26 +618,28 @@ struct DefaultConv2dFprop < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< cutlass::MatrixShape, ElementA, LayoutA, ThreadMapA, - AlignmentA + AccessTypeA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< cutlass::MatrixShape, ElementB, LayoutB, ThreadMapB, - AlignmentB + AccessTypeB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -624,7 +649,7 @@ struct DefaultConv2dFprop < using MmaPolicy = typename MmaCore::MmaPolicy; static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * AlignmentB) == 128) + ((sizeof_bits::value * AlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; @@ -679,9 +704,10 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, - int InterleavedK, + conv::StrideSupport StrideSupport, int AlignmentA, - int AlignmentB + int AlignmentB, + int InterleavedK > struct DefaultConv2dFprop < ElementA, @@ -701,6 +727,7 @@ struct DefaultConv2dFprop < Stages, MathOperatorTag, IteratorAlgorithm::kOptimized, + StrideSupport, AlignmentA, AlignmentB > { @@ -774,6 +801,8 @@ struct DefaultConv2dFprop < >; }; +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm /// and 2 stage pipeline. template < @@ -791,6 +820,7 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -812,6 +842,7 @@ struct DefaultConv2dFprop < 2, MathOperatorTag, IteratorAlgorithm::kOptimized, + StrideSupport, AlignmentA, AlignmentB > { @@ -824,6 +855,7 @@ struct DefaultConv2dFprop < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< @@ -831,7 +863,7 @@ struct DefaultConv2dFprop < ElementA, LayoutA, ThreadMapA, - AlignmentA + AccessTypeA > >; @@ -839,6 +871,7 @@ struct DefaultConv2dFprop < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< @@ -846,7 +879,7 @@ struct DefaultConv2dFprop < ElementB, LayoutB, ThreadMapB, - AlignmentB + AccessTypeB > >; @@ -905,9 
+938,10 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, - int InterleavedK, + conv::StrideSupport StrideSupport, int AlignmentA, - int AlignmentB + int AlignmentB, + int InterleavedK > struct DefaultConv2dFprop < ElementA, @@ -927,6 +961,7 @@ struct DefaultConv2dFprop < 2, MathOperatorTag, IteratorAlgorithm::kOptimized, + StrideSupport, AlignmentA, AlignmentB > { @@ -1023,6 +1058,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -1044,6 +1080,7 @@ struct DefaultConv2dFprop < Stages, MathOperatorTag, IteratorAlgorithm::kAnalytic, + StrideSupport, AlignmentA, AlignmentB > { @@ -1132,6 +1169,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -1153,6 +1191,7 @@ struct DefaultConv2dFprop < Stages, MathOperatorTag, IteratorAlgorithm::kOptimized, + StrideSupport, AlignmentA, AlignmentB > { @@ -1241,6 +1280,7 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -1262,6 +1302,7 @@ struct DefaultConv2dFprop < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, + StrideSupport, AlignmentA, AlignmentB > { @@ -1351,6 +1392,7 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, typename MathOperatorTag, + conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB > @@ -1372,6 +1414,7 @@ struct DefaultConv2dFprop < 2, MathOperatorTag, IteratorAlgorithm::kOptimized, + StrideSupport, AlignmentA, AlignmentB > { diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h index 11a01ac5..6f127440 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ 
b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -65,7 +65,11 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value > struct DefaultConv2dFpropWithBroadcast { @@ -84,7 +88,9 @@ struct DefaultConv2dFpropWithBroadcast { Stages, MathOperatorTag, IteratorAlgorithm, - StrideSupport + StrideSupport, + AlignmentA, + AlignmentB >::Kernel; // Replace epilogue diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h index baf49a8a..5dd31052 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -66,7 +66,11 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value > struct DefaultConv2dFpropWithReduction { @@ -85,7 +89,9 @@ struct DefaultConv2dFpropWithReduction { Stages, MathOperatorTag, IteratorAlgorithm, - StrideSupport + StrideSupport, + AlignmentA, + AlignmentB >::Kernel; // Replace epilogue diff --git a/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/include/cutlass/conv/kernel/default_conv2d_wgrad.h index 336ebb3b..91edde2d 100644 --- 
a/include/cutlass/conv/kernel/default_conv2d_wgrad.h +++ b/include/cutlass/conv/kernel/default_conv2d_wgrad.h @@ -67,8 +67,13 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value > struct DefaultConv2dWgrad; + ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -93,7 +98,10 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB > struct DefaultConv2dWgrad < ElementA, @@ -112,7 +120,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -123,22 +134,26 @@ struct DefaultConv2dWgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, - ThreadMapA + ThreadMapA, + AccessTypeA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = 
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + AccessTypeB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -179,6 +194,7 @@ struct DefaultConv2dWgrad < conv::Operator::kWgrad >; }; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dWgrad specialzation for Analytic IteratorAlgorithm and two @@ -198,7 +214,10 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB > struct DefaultConv2dWgrad < ElementA, @@ -217,7 +236,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -228,12 +250,14 @@ struct DefaultConv2dWgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, - ThreadMapA + ThreadMapA, + AccessTypeA > >; @@ -241,12 +265,14 @@ struct DefaultConv2dWgrad < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + AccessTypeB > >; @@ -308,7 +334,10 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + 
conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB > struct DefaultConv2dWgrad < ElementA, @@ -327,7 +356,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -338,22 +370,26 @@ struct DefaultConv2dWgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< cutlass::MatrixShape, ElementA, - ThreadMapA + ThreadMapA, + AccessTypeA >; using SmemIteratorA = typename MmaCore::SmemIteratorA; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + AccessTypeB >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -394,6 +430,7 @@ struct DefaultConv2dWgrad < conv::Operator::kWgrad >; }; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dWgrad specialzation for Optimized IteratorAlgorithm and two @@ -413,7 +450,10 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB > struct DefaultConv2dWgrad < ElementA, @@ -432,7 +472,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB > { // Define the core components from GEMM @@ -443,12 +486,14 @@ struct 
DefaultConv2dWgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; using IteratorA = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< cutlass::MatrixShape, ElementA, - ThreadMapA + ThreadMapA, + AccessTypeA > >; @@ -456,12 +501,14 @@ struct DefaultConv2dWgrad < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; using IteratorB = cutlass::conv::threadblock::TileIterator< cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + AccessTypeB > >; @@ -524,7 +571,10 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB > struct DefaultConv2dWgrad < ElementA, @@ -543,7 +593,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport, + AccessTypeA, + AccessTypeB > { // Define the core components from GEMM @@ -629,7 +682,10 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB > struct DefaultConv2dWgrad < ElementA, @@ -648,7 +704,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport, + AccessTypeA, + AccessTypeB > { // Define the core components from GEMM @@ -732,7 +791,10 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename 
MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB > struct DefaultConv2dWgrad < ElementA, @@ -751,7 +813,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport, + AccessTypeA, + AccessTypeB > { // Define the core components from GEMM @@ -817,7 +882,6 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, conv::Operator::kWgrad >; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -838,7 +902,10 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB > struct DefaultConv2dWgrad < ElementA, @@ -857,7 +924,10 @@ struct DefaultConv2dWgrad < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport, + AccessTypeA, + AccessTypeB > { // Define the core components from GEMM @@ -925,12 +995,11 @@ struct DefaultConv2dWgrad < >; }; -///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace kernel } // namespace conv } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/kernel/default_conv3d_dgrad.h b/include/cutlass/conv/kernel/default_conv3d_dgrad.h index 6554099a..d51ff21c 100644 --- a/include/cutlass/conv/kernel/default_conv3d_dgrad.h +++ b/include/cutlass/conv/kernel/default_conv3d_dgrad.h @@ -228,7 +228,6 @@ struct DefaultConv3dDgrad < // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using IteratorA = 
cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< cutlass::MatrixShape, diff --git a/include/cutlass/conv/kernel/default_conv3d_wgrad.h b/include/cutlass/conv/kernel/default_conv3d_wgrad.h index 3fcdc7d0..728cf945 100644 --- a/include/cutlass/conv/kernel/default_conv3d_wgrad.h +++ b/include/cutlass/conv/kernel/default_conv3d_wgrad.h @@ -501,4 +501,3 @@ struct DefaultConv3dWgrad < } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h index a7f4dc0b..da8ae974 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -59,7 +59,8 @@ template < typename Shape_, typename Element_, typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray > class Conv2dDgradFilterTileAccessIteratorAnalytic; @@ -70,13 +71,15 @@ class Conv2dDgradFilterTileAccessIteratorAnalytic; template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ > class Conv2dDgradFilterTileAccessIteratorAnalytic < Shape_, Element_, ThreadMap_, - conv::StrideSupport::kStrided + conv::StrideSupport::kStrided, + AccessType_ > { public: @@ -88,7 +91,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -97,7 +100,12 @@ public: static StrideSupport const 
kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or larger."); @@ -107,14 +115,13 @@ public: using Params = Conv2dAnalyticParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension @@ -162,8 +169,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -213,7 +222,7 @@ public: TensorCoord coord = at(); - return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C; } /// Returns a pointer to the vector starting at the current coordinate @@ -223,13 +232,19 @@ public: TensorCoord coord = at(); LongIndex offset = params_.layout(coord); - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + return 
reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -249,7 +264,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } @@ -263,13 +278,15 @@ public: template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ > class Conv2dDgradFilterTileAccessIteratorAnalytic < Shape_, Element_, ThreadMap_, - conv::StrideSupport::kUnity + conv::StrideSupport::kUnity, + AccessType_ >{ public: @@ -281,7 +298,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -290,7 +307,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or larger."); @@ -306,6 +328,7 @@ private: Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex 
iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension @@ -348,8 +371,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -395,7 +420,7 @@ public: TensorCoord coord = at(); - return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C; } /// Returns a pointer to the vector starting at the current coordinate @@ -405,13 +430,18 @@ public: TensorCoord coord = at(); LongIndex offset = params_.layout(coord); - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); - + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -431,7 +461,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } @@ -446,5 +476,3 @@ 
public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h index 8638bc63..723a8f6d 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -60,7 +60,8 @@ template < typename Shape_, typename Element_, typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray > class Conv2dDgradFilterTileAccessIteratorOptimized; @@ -71,13 +72,15 @@ class Conv2dDgradFilterTileAccessIteratorOptimized; template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ > class Conv2dDgradFilterTileAccessIteratorOptimized < Shape_, Element_, ThreadMap_, - conv::StrideSupport::kUnity + conv::StrideSupport::kUnity, + AccessType_ > { public: @@ -89,7 +92,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -98,9 +101,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; - + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + // // Parameters structure // 
@@ -141,9 +147,10 @@ private: Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; - uint32_t predicates_; + uint32_t predicates_[kAccessesPerVector]; int filter_rs_; int filter_k_; @@ -169,7 +176,7 @@ public: params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), - predicates_(0), + predicates_{0}, filter_rs_(0), filter_k_(0) { @@ -186,11 +193,15 @@ public: int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; int filter_c = column + c * ThreadMap::Delta::kContiguous; - uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { - int pred_idx = c + s * ThreadMap::Iterations::kContiguous; - - predicates_ |= (pred << pred_idx); + uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } } } @@ -204,8 +215,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -234,7 +247,11 @@ public: for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); - predicates_ = (predicates_ & (~kClearMask)); + 
+ CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + predicates_[v] = (predicates_[v] & (~kClearMask)); + } } } @@ -245,19 +262,25 @@ public: CUTLASS_HOST_DEVICE bool valid() { LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; - return (predicates_ & (1u << pred_idx)); + return (predicates_[iteration_vector_] & (1u << pred_idx)); } /// Returns a pointer to the vector starting at the current coordinate CUTLASS_HOST_DEVICE AccessType const *get() const { return reinterpret_cast(pointer_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8); + iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -282,7 +305,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } @@ -297,5 +320,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index efb430dd..cb8410e4 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -59,7 
+59,8 @@ template < typename Shape_, typename Element_, typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided, + typename AccessType_ = cutlass::AlignedArray > class Conv2dDgradOutputGradientTileAccessIteratorAnalytic; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -69,13 +70,15 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic; template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ > class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < Shape_, Element_, ThreadMap_, - conv::StrideSupport::kStrided + conv::StrideSupport::kStrided, + AccessType_ > { public: @@ -86,7 +89,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -95,7 +98,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or greater."); @@ -112,14 +120,13 @@ public: using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex 
iteration_vector_; char const *pointer_; int filter_k_; @@ -211,8 +218,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -277,7 +286,7 @@ public: coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.P && coord.w() >= 0 && coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; + (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K; } /// Returns a pointer to the vector starting at the current coordinate @@ -287,12 +296,18 @@ public: TensorCoord coord = at(); LongIndex offset = params_.layout(coord); - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -312,14 +327,14 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } return Status::kSuccess; } - }; + 
///////////////////////////////////////////////////////////////////////////////////////////////// // Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by @@ -327,13 +342,15 @@ public: template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ > class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < Shape_, Element_, ThreadMap_, - conv::StrideSupport::kUnity + conv::StrideSupport::kUnity, + AccessType_ > { public: @@ -344,7 +361,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -353,7 +370,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or greater."); @@ -368,8 +390,6 @@ public: // Parameters structure // - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - struct Params { Layout layout; @@ -395,6 +415,7 @@ private: Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; int filter_k_; @@ -446,8 +467,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / 
ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -497,7 +520,6 @@ public: } - /// Returns true if the current coordinate is within the output tensor Dy CUTLASS_HOST_DEVICE bool valid() const { @@ -507,7 +529,7 @@ public: return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.P && coord.w() >= 0 && coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; + (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K; } /// Returns a pointer to the vector starting at the current coordinate @@ -517,12 +539,18 @@ public: TensorCoord coord = at(); LongIndex offset = params_.layout(coord); - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -548,7 +576,7 @@ public: } // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } @@ -556,7 +584,9 @@ public: } }; + ///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h 
b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h index 200005cd..ef916dde 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -61,7 +61,8 @@ template < typename Shape_, typename Element_, typename ThreadMap_, - conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray > class Conv2dDgradOutputGradientTileAccessIteratorOptimized; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -74,14 +75,16 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized; template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ > class Conv2dDgradOutputGradientTileAccessIteratorOptimized < Shape_, Element_, ThreadMap_, - conv::StrideSupport::kUnity -> { + conv::StrideSupport::kUnity, + AccessType_ +> { public: // @@ -93,7 +96,7 @@ public: using Layout = layout::TensorNHWC; using TensorCoord = typename Layout::TensorCoord; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; @@ -101,7 +104,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + using Mask = uint64_t; // @@ -116,14 +124,13 @@ public: using Params = 
Conv2dDgradOutputGradientIteratorOptimizedParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; // One pointer per access char const *pointer_[ThreadMap::Iterations::kStrided]; @@ -133,7 +140,7 @@ private: int filter_s_; int filter_k_; - Index masks_[ThreadMap::Iterations::kStrided][2]; + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; public: @@ -201,7 +208,11 @@ public: int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); - masks_[s_idx][0] |= (pred << r); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } } } @@ -218,12 +229,17 @@ public: int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; bool pred = (q >= 0 && q < problem_size_.Q); - masks_[s_idx][1] |= (pred << s); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } } } - if (filter_k_ >= problem_size.K) { - clear_mask(); + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_k_ >= problem_size.K); } set_iteration_index(0); @@ -269,62 +285,15 @@ private: } } - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask_(bool clear) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - // We are using inline PTX assembly here to avoid an CUDA C++ compilation - // artifact in which control flow instructions are generated. Instead, our - // intent is to predicate the mov instructions. 
- #if defined(__CUDA_ARCH__) - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][0]) - : - "r"((int)clear), - "r"(masks_[s][0]) - ); - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][1]) - : - "r"((int)clear), - "r"(masks_[s][1]) - ); - #else - if (clear) { - masks_[s][0] = 0; - masks_[s][1] = 0; - } - #endif - } - } - public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of element @@ -359,16 +328,32 @@ public: filter_k_ += params_.filter_k_delta; } - clear_mask_(filter_k_ >= problem_size_.K); + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); + } } /// Clears the predicates CUTLASS_HOST_DEVICE - void clear_mask() { + void clear_mask(bool clear = true) { CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][0] = Mask(0); - masks_[s][1] = Mask(0); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? 
Mask(0) : masks_[s][v][1]; + } + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; } } @@ -376,20 +361,25 @@ public: bool valid() { return - (masks_[iteration_strided_][0] & (Index(1) << filter_r_)) && - (masks_[iteration_strided_][1] & (Index(1) << filter_s_)); + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); } /// Returns a pointer to the vector starting at the current coordinate CUTLASS_HOST_DEVICE AccessType const *get() const { - return reinterpret_cast(pointer_[iteration_strided_]); + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { @@ -416,7 +406,7 @@ public: } // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorNotSupported; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h index 4476be4f..3dba1a42 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -60,7 +60,8 @@ template < typename Shape_, typename Element_, typename 
Layout_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray > class Conv2dFpropActivationTileAccessIteratorAnalytic { public: @@ -74,7 +75,7 @@ public: using Layout = Layout_; using TensorCoord = typename Layout::TensorCoord; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; @@ -82,7 +83,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + // // Simplifying assertions // @@ -95,14 +101,13 @@ public: using Params = Conv2dAnalyticParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; int filter_c_; @@ -156,8 +161,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -214,7 +221,7 @@ public: return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < 
problem_size_.H && coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; + (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C; } /// Returns a pointer to the vector starting at the current coordinate @@ -224,7 +231,7 @@ public: TensorCoord coord = at(); LongIndex offset = params_.layout(coord); - AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8) + iteration_vector_; return ptr; } @@ -232,6 +239,12 @@ public: /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -252,7 +265,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index 15839b01..4ca15ae9 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -61,7 +61,7 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - int AccessSize = ThreadMap_::kElementsPerAccess + typename AccessType_ = cutlass::AlignedArray > class Conv2dFpropActivationTileAccessIteratorOptimized { public: @@ -75,7 +75,7 @@ public: using Layout = Layout_; using TensorCoord = typename 
Layout::TensorCoord; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; @@ -86,6 +86,11 @@ public: using Mask = uint64_t; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + // // Simplifying assertions // @@ -98,8 +103,6 @@ public: using Params = Conv2dFpropActivationIteratorOptimizedParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; @@ -213,10 +216,10 @@ public: } } - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); - } + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + } set_iteration_index(0); } @@ -260,56 +263,7 @@ private: pointer_[s] += byte_offset; } } - - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask_(bool clear, int index) { - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - // We are using inline PTX assembly here to avoid an CUDA C++ compilation - // artifact in which control flow instructions are generated. Instead, our - // intent is to predicate the mov instructions. 
- #if defined(__CUDA_ARCH__) - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][index][0]) - : - "r"((int)clear), - "r"(masks_[s][index][0]) - ); - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(masks_[s][index][1]) - : - "r"((int)clear), - "r"(masks_[s][index][1]) - ); - #else - if (clear) { - masks_[s][index][0] = 0; - masks_[s][index][1] = 0; - } - #endif - } - } - public: /// Overrides the internal iteration index @@ -354,23 +308,33 @@ public: filter_c_ += params_.filter_c_delta; } - CUTLASS_PRAGMA_UNROLL - for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); - } + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + } } - + /// Clears the predicates CUTLASS_HOST_DEVICE - void clear_mask() { + void clear_mask(bool clear = true) { CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { CUTLASS_PRAGMA_UNROLL for (int v = 0; v < kAccessesPerVector; ++v) { - masks_[s][v][0] = Mask(0); - masks_[s][v][1] = Mask(0); + masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; + masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; } } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; + masks_[s][v][1] = clear ? 
0 : masks_[s][v][1]; + } } CUTLASS_HOST_DEVICE @@ -396,7 +360,6 @@ public: if (iteration_vector_ < kAccessesPerVector) { return *this; } - iteration_vector_ = 0; ++iteration_contiguous_; @@ -419,7 +382,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessSize) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index 46c01bd3..f815b707 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -59,7 +59,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray > class Conv2dFpropFilterTileAccessIteratorAnalytic { public: @@ -72,7 +73,7 @@ public: using Element = Element_; using Layout = Layout_; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -81,7 +82,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + // // Simplifying assertions // @@ -94,14 +100,13 @@ public: using Params = Conv2dAnalyticParams; - static int const kAccessesPerVector = 
ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; int filter_r_; @@ -142,8 +147,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -187,7 +194,7 @@ public: TensorCoord coord = at(); return coord.n() < problem_size_.K && - coord.c() < problem_size_.C; + (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C; } /// Returns a pointer to the vector starting at the current coordinate @@ -197,12 +204,18 @@ public: TensorCoord coord = at(); LongIndex offset = params_.layout(coord); - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -223,7 +236,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return 
Status::kErrorInvalidProblem; } @@ -250,5 +263,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 218292ae..1ff04531 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -61,7 +61,7 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - int AccessSize = ThreadMap_::kElementsPerAccess + typename AccessType_ = cutlass::AlignedArray > class Conv2dFpropFilterTileAccessIteratorOptimized{ public: @@ -74,7 +74,7 @@ public: using Element = Element_; using Layout = Layout_; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -83,15 +83,18 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + // // Simplifying assertions // static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - // // Parameters structure // @@ -170,6 +173,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { uint32_t pred = ((column + s * 
ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { predicates_[v_idx] |= (pred << s); @@ -178,7 +182,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); } pointer_ += ( @@ -188,41 +192,11 @@ public: set_iteration_index(0); } - /// Clears the predicates - CUTLASS_HOST_DEVICE - void clear_mask_(bool clear, int index) { - // We are using inline PTX assembly here to avoid an CUDA C++ compilation - // artifact in which control flow instructions are generated. Instead, our - // intent is to predicate the mov instructions. - #if defined(__CUDA_ARCH__) - asm volatile( - "{\n" - " .reg .pred p;\n" - " .reg .u32 m;" - " mov.u32 m, %2;" - " setp.ne.b32 p, %1, 0;\n" - " @p mov.u32 m, 0;\n" - " mov.u32 %0, m;\n" - "}\n" - : - "=r"(predicates_[index]) - : - "r"((int)clear), - "r"(predicates_[index]) - ); - #else - if (clear) { - predicates_[index] = 0; - } - #endif - } - /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { iteration_vector_ = index % kAccessesPerVector; int residual_access = index / kAccessesPerVector; - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } @@ -246,15 +220,21 @@ public: next = params_.inc_next_c; filter_c_ += params_.filter_c_delta; } - + CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx); + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); } pointer_ += next; } + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + predicates_[v] = clear ? 
0u : predicates_[v]; + } + /// Returns true if the current coordinate is within the filter tensor W CUTLASS_HOST_DEVICE bool valid() { @@ -274,7 +254,6 @@ public: if (iteration_vector_ < kAccessesPerVector) { return *this; } - iteration_vector_ = 0; ++iteration_contiguous_; @@ -301,7 +280,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessSize) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h index bbdc80b8..72c52c2e 100644 --- a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -68,6 +68,7 @@ public: using Params = typename TileAccessIterator::Params; static int const kConvDim = TileAccessIterator::kConvDim; using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; /// Fragment object to be loaded or stored using Fragment = cutlass::Array< @@ -130,18 +131,20 @@ public: for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) { + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + cutlass::arch::global_load< AccessType, sizeof(AccessType) >( - frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v], + frag_ptr[idx], tile_access_iterator_.get() + pointer_offset, tile_access_iterator_.valid() ); - + ++tile_access_iterator_; } } diff --git 
a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h index 2bfc527a..901582e8 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -58,7 +58,8 @@ namespace threadblock { template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray > class Conv2dWgradActivationTileAccessIteratorAnalytic { public: @@ -70,7 +71,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -79,7 +80,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); @@ -89,14 +95,13 @@ public: using Params = Conv2dAnalyticParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; // Filter postion (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k @@ -149,8 +154,10 @@ public: /// Overrides the internal iteration 
index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -173,9 +180,19 @@ public: /// by the iterator. CUTLASS_HOST_DEVICE TensorCoord at() const { + int r, s, c; - int r = filter_r_[iteration_contiguous_]; - int s = filter_s_[iteration_contiguous_]; + if (kAccessesPerVector == 1) { + r = filter_r_[iteration_contiguous_]; + s = filter_s_[iteration_contiguous_]; + c = filter_c_[iteration_contiguous_]; + } else { + c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C; + int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C; + s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S; + int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S; + r = filter_r_[iteration_contiguous_] + wrap_s; + } if (problem_size_.mode == Mode::kConvolution) { r = (problem_size_.R - 1 - r); @@ -184,14 +201,14 @@ public: int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); - + int p = residual / problem_size_.Q; int q = residual % problem_size_.Q; - + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; - - return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]); + + return TensorCoord(n, h, w, c); } /// Returns true if the current coordinate is within the activation 
tensor x @@ -201,8 +218,7 @@ public: return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; + coord.w() >= 0 && coord.w() < problem_size_.W; } /// Returns a pointer to the vector starting at the current coordinate @@ -218,6 +234,12 @@ public: /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -237,13 +259,12 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } return Status::kSuccess; } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h index 09cf6ffa..cb96594b 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -57,7 +57,8 @@ namespace threadblock { template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray > class Conv2dWgradActivationTileAccessIteratorOptimized { public: @@ -69,7 +70,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using 
TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -78,7 +79,12 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); @@ -88,14 +94,13 @@ public: using Params = Conv2dWgradActivationIteratorOptimizedParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; // Precomputed effective filter postion (r,s) in contiguous dimension stays constant for each gemm_iteration_k @@ -153,9 +158,8 @@ public: s = (problem_size_.S - 1 - s); } - precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h; - precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w; - + precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h; + precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w; } // initialize n, p, q offset for every strided iteration @@ -170,8 +174,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + 
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -194,6 +200,31 @@ public: /// by the iterator. CUTLASS_HOST_DEVICE TensorCoord at() const { + int r = precomputed_filter_r_[iteration_contiguous_]; + int s = precomputed_filter_s_[iteration_contiguous_]; + int c = filter_c_[iteration_contiguous_]; + + if (kAccessesPerVector > 1) { + int wrap_c; + params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements); + + if (problem_size_.mode == Mode::kConvolution) { + s -= (problem_size_.dilation_w * wrap_c); + + int wrap_s = (s == -problem_size_.pad_w - problem_size_.dilation_w); + s = wrap_s ? (-problem_size_.pad_w + (problem_size_.S - 1) * problem_size_.dilation_w): s; + + r -= (problem_size_.dilation_h * wrap_s); + + } else { + s += (problem_size_.dilation_w * wrap_c); + + int wrap_s = (s == (-problem_size_.pad_w + problem_size_.S * problem_size_.dilation_w)); + s = wrap_s ? 
-problem_size_.pad_w : s; + + r += (problem_size_.dilation_h * wrap_s); + } + } // The subseqnet fast_divmod() operations are equivalent to the following logical computation: // @@ -209,10 +240,10 @@ public: params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]); params_.q_divmod(p, q, residual); - int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_]; - int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_]; + int h = p * problem_size_.stride_h + r; + int w = q * problem_size_.stride_w + s; - return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]); + return TensorCoord(n, h, w, c); } /// Returns true if the current coordinate is within the activation tensor x @@ -222,8 +253,7 @@ public: return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && - coord.w() >= 0 && coord.w() < problem_size_.W && - coord.c() < problem_size_.C; + coord.w() >= 0 && coord.w() < problem_size_.W; } /// Returns a pointer to the vector starting at the current coordinate @@ -239,6 +269,12 @@ public: /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dWgradActivationTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -258,14 +294,14 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } return Status::kSuccess; } - }; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock diff --git 
a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h index c9e12699..e43bc534 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -58,7 +58,8 @@ namespace threadblock { template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray > class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { public: @@ -70,7 +71,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -80,6 +81,11 @@ public: static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); @@ -89,14 +95,13 @@ public: using Params = Conv2dAnalyticParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; int filter_k_[ThreadMap::Iterations::kContiguous]; @@ -143,8 +148,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = 
index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -187,7 +194,7 @@ public: return coord.n() < problem_size_.N && coord.h() < problem_size_.P && coord.w() < problem_size_.Q && - coord.c() < problem_size_.K; + (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K; } /// Returns a pointer to the vector starting at the current coordinate @@ -197,12 +204,18 @@ public: TensorCoord coord = at(); LongIndex offset = params_.layout(coord); - return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -222,14 +235,14 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } return Status::kSuccess; } - }; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock @@ -237,5 +250,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git 
a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h index 8a8efb14..deacd8d7 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -57,7 +57,8 @@ namespace threadblock { template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray > class Conv2dWgradOutputGradientTileAccessIteratorOptimized { public: @@ -69,7 +70,7 @@ public: using Element = Element_; using Layout = layout::TensorNHWC; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + using AccessType = AccessType_; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -79,6 +80,11 @@ public: static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); @@ -88,17 +94,16 @@ public: using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; - uint32_t predicates_; + uint32_t predicates_[kAccessesPerVector]; int filter_k_; int offset_npq_; @@ -115,7 +120,7 @@ public: 
params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), - predicates_(0), + predicates_{0}, filter_k_(0), offset_npq_(0) { @@ -132,13 +137,16 @@ public: int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided; - bool predicate = valid_(at_(offset_npq, filter_k)); - - uint32_t pred = (predicate ? 1u : 0); - - int pred_idx = c + s * ThreadMap::Iterations::kContiguous; - - predicates_ |= (pred << pred_idx); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements)); + + uint32_t pred = (predicate ? 1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } } } @@ -165,8 +173,10 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -185,7 +195,11 @@ public: for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) { uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); - predicates_ = (predicates_ & (~kClearMask)); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + predicates_[v] = (predicates_[v] & (~kClearMask)); + } } } @@ -231,7 +245,7 @@ public: bool valid() const { LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; - return 
(predicates_ & (1u << pred_idx)); + return (predicates_[iteration_vector_] & (1u << pred_idx)); } /// Returns a pointer to the vector starting at the current coordinate @@ -242,12 +256,18 @@ public: pointer_ + iteration_strided_ * params_.offset_next_strided + iteration_contiguous_ * params_.offset_next_contiguous - ); + ) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -267,14 +287,14 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } return Status::kSuccess; } - }; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock @@ -282,5 +302,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h index 20296109..f21cec8c 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -79,11 +79,10 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; static_assert(sizeof_bits::value >= 8, "DGRAD 
requires elements of size 8b or larger."); - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Parameters structure @@ -261,5 +260,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h index 8dc0a860..e65feb4c 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -82,8 +82,7 @@ public: static StrideSupport const kStrideSupport = StrideSupport_; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static int const kAccessesPerVector = 1; // // Parameters structure @@ -217,7 +216,8 @@ public: CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { - uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + predicates_ = (predicates_ & (~kClearMask)); } } @@ -281,5 +281,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h index 212637a4..2eff5751 100644 --- 
a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -93,11 +93,10 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or greater."); - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Simpligying assertions @@ -328,11 +327,11 @@ public: } }; + ///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace conv } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h index 4333f343..86854354 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -86,7 +86,7 @@ public: static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; using Coord3D = Coord<3>; - + static int const kAccessesPerVector = 1; using Mask = uint64_t; // @@ -101,8 +101,6 @@ public: using Params = Conv3dDgradOutputGradientIteratorOptimizedParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; @@ -403,7 +401,6 @@ public: } clear_mask_(filter_k_ >= problem_size_.K); - } diff --git 
a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h index 7ba3eb1b..009d0b1a 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -81,6 +81,7 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; // // Simplifying assertions @@ -94,8 +95,6 @@ public: using Params = Conv3dAnalyticParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h index 7193d2a9..25b5d3cf 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -82,7 +82,7 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; - + static int const kAccessesPerVector = 1; using Mask = uint64_t; // @@ -97,8 +97,6 @@ public: using Params = Conv3dFpropActivationIteratorOptimizedParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index 1ce9dba9..b56cee75 
100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -80,6 +80,7 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; // // Simplifying assertions @@ -93,8 +94,6 @@ public: using Params = Conv3dAnalyticParams; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index 8dd07c60..c19de1af 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -82,6 +82,7 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; // // Simplifying assertions @@ -89,8 +90,6 @@ public: static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - // // Parameters structure // @@ -156,7 +155,7 @@ public: params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), - predicates_(0), + predicates_{0}, filter_trs_(0), filter_c_(0) { diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h index e154ef4e..039272df 100644 --- 
a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -79,11 +79,11 @@ public: static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 2960fcce..c9a0df93 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -79,12 +79,10 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; - + static int const kAccessesPerVector = 1; static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h index 153b37f6..2779b1c9 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -78,12 +78,10 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const 
kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; - + static int const kAccessesPerVector = 1; static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index 64d5e79f..9ba29b06 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -79,12 +79,10 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; - + static int const kAccessesPerVector = 1; static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index 58c7a409..03b4c1ab 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -216,6 +216,7 @@ public: for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_A.get(), iterator_A.valid()); + ++iterator_A; } @@ -244,6 +245,7 @@ public: for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); + ++iterator_B; } ++this->smem_iterator_B_; @@ -289,16 +291,17 @@ public: CUTLASS_PRAGMA_UNROLL for (int v = 0; v < 
IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_A.get(), iterator_A.valid()); ++iterator_A; } + ++this->smem_iterator_A_; } @@ -313,17 +316,18 @@ this->smem_iterator_B_.get()); CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - + cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); - + ++iterator_B; } + ++this->smem_iterator_B_; } diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu index 856f89bc..2af98674 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu @@ -39,6 +39,7 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128_64x3_64x64x64) { @@ -80,6 +81,7 @@ TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens } //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128_64x3_64x64x64) { @@ -121,4 +123,115 @@ TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten }
//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, + 128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity, + 4, + 4 + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + 
+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, + 128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity, + 4, + 4 + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git 
a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu index 277abb32..8dfb2d5c 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu @@ -38,46 +38,7 @@ #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// -TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, - 128x128_32x2_64x64x32) { - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, cutlass::layout::TensorNHWC, - ElementB, cutlass::layout::TensorNHWC, - ElementC, cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm75, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic, - cutlass::conv::StrideSupport::kUnity - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - EXPECT_TRUE(test::conv::device::TestAllConv2d()); -} - -//////////////////////////////////////////////////////////////////////////////// 
TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, 128x128_32x2_64x64x32) { @@ -118,6 +79,7 @@ TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens } //////////////////////////////////////////////////////////////////////////////// + TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, 128x128_32x2_64x64x32) { @@ -158,4 +120,113 @@ TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten } //////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity, + 2, + 2 + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + 
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity, + 2, + 2 + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, 
stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + #endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu index 66278eeb..04474f2f 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu @@ -38,6 +38,7 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128_64x3_64x64x64) { @@ -78,6 +79,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens } //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128_64x3_64x64x64) { @@ -118,9 +120,10 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten } //////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, + +TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, 128x128_64x3_64x64x64) { - + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -141,28 +144,41 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten cutlass::gemm::GemmShape<16, 8, 16>, 
cutlass::epilogue::thread::LinearCombination< ElementC, - 128 / cutlass::sizeof_bits::value, + 8, ElementAccumulator, ElementCompute >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided, 2, 2 >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); } //////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, + +TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, 128x128_64x3_64x64x64) { - + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -183,7 +199,7 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementC, - 128 / cutlass::sizeof_bits::value, + 8, ElementAccumulator, ElementCompute >, @@ -191,14 +207,81 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten 3, cutlass::arch::OpMultiplyAdd, cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, + 2, + 
2 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, + 128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 8, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, 4, 4 >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - + + test::conv::device::Conv2dProblemVector problem_size_list; 
+ + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu index 382eea83..c5710dc0 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu @@ -119,49 +119,6 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten //////////////////////////////////////////////////////////////////////////////// -TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align1, - 128x128_32x2_64x64x32) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, cutlass::layout::TensorNHWC, - ElementB, cutlass::layout::TensorNHWC, - ElementC, cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm75, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - 
cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized, - 1, - 1 - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); -} - -//////////////////////////////////////////////////////////////////////////////// - TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align2, 128x128_32x2_64x64x32) { @@ -185,7 +142,7 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten cutlass::gemm::GemmShape<16, 8, 8>, cutlass::epilogue::thread::LinearCombination< ElementC, - 128 / cutlass::sizeof_bits::value, + 8, ElementAccumulator, ElementCompute >, @@ -193,14 +150,26 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten 2, cutlass::arch::OpMultiplyAdd, cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, 2, 2 >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); } //////////////////////////////////////////////////////////////////////////////// 
@@ -228,7 +197,7 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten cutlass::gemm::GemmShape<16, 8, 8>, cutlass::epilogue::thread::LinearCombination< ElementC, - 128 / cutlass::sizeof_bits::value, + 8, ElementAccumulator, ElementCompute >, @@ -236,15 +205,83 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten 2, cutlass::arch::OpMultiplyAdd, cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, 4, 4 >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); } //////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + 
cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 8, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided, + 4, + 4 + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + #endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu index d2d29115..090c4244 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -60,7 +60,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_te cutlass::gemm::GemmShape<16, 8, 8>, cutlass::epilogue::thread::LinearCombination< ElementC, - 128 / cutlass::sizeof_bits::value, + 8, ElementAccumulator, ElementCompute >, @@ -79,53 +79,9 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_te 
//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1, - 128x128_32x3_64x64x32) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::tfloat32_t; - using ElementB = cutlass::tfloat32_t; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, cutlass::layout::TensorNHWC, - ElementB, cutlass::layout::TensorNHWC, - ElementC, cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 16>, - cutlass::gemm::GemmShape<64, 64, 16>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 3, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized, - 1, - 1 - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); -} - -//////////////////////////////////////////////////////////////////////////////// - TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2, 128x128_32x3_64x64x32) { - + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) using ElementA = cutlass::tfloat32_t; using ElementB = cutlass::tfloat32_t; @@ -146,7 +102,7 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_t cutlass::gemm::GemmShape<16, 8, 8>, cutlass::epilogue::thread::LinearCombination< ElementC, - 128 / cutlass::sizeof_bits::value, 
+ 8, ElementAccumulator, ElementCompute >, @@ -154,15 +110,26 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_t 3, cutlass::arch::OpMultiplyAdd, cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, 2, 2 >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu index 3b36fd6c..d8a3ea10 100644 --- a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -174,6 +174,61 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32n /// Run all unit test sizes with device-level Conv2d instance EXPECT_TRUE(test::conv::device::TestAllConv2d()); } + //////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, + 128x128_32x3_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + 
using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided, + 4, + 4 + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 809c8e7d..e2b41233 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -116,6 +116,7 @@ public: cutlass::Distribution::Kind dist_kind, uint64_t seed) { 
+//cutlass::reference::host::TensorFill(view, Element(1.0f)); if (dist_kind == cutlass::Distribution::Uniform) { int scope; diff --git a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu index 44a721e5..9783f94b 100644 --- a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu +++ b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu @@ -36,6 +36,8 @@ #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) +//////////////////////////////////////////////////////////////////////////////// + TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128_32x2_64x64x32) { @@ -74,5 +76,114 @@ TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens } //////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + 
cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided, + 4, + 4 + >::Kernel; + + using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, + 4, + 4 + >::Kernel; + + using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; + + 
test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + #endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED diff --git a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu index 13f6fc87..3180ffae 100644 --- a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -146,8 +146,7 @@ TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized, - cutlass::conv::StrideSupport::kStrided + cutlass::conv::IteratorAlgorithm::kOptimized >::Kernel; using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; @@ -157,5 +156,113 @@ TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten } //////////////////////////////////////////////////////////////////////////////// -#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED +TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, + 128x128_32x3_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = 
cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided, + 4, + 4 + >::Kernel; + + using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, + 128x128_32x3_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< + ElementA, 
cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, + 4, + 4 + >::Kernel; + + using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 4, 4, 12}, // input size (NHWC) + {8, 3, 3, 12}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/tools/library/scripts/conv2d_operation.py b/tools/library/scripts/conv2d_operation.py index b6757072..11d7f384 100644 --- a/tools/library/scripts/conv2d_operation.py +++ b/tools/library/scripts/conv2d_operation.py @@ -103,9 +103,9 @@ class Conv2dOperation: ) if self.stride_support == StrideSupport.Unity: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride" + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}_unity_stride" else: - configuration_name = 
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}" + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" return SubstituteTemplate( configuration_name, @@ -114,6 +114,7 @@ class Conv2dOperation: 'extended_name': self.extended_name(), 'threadblock': threadblock, 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, } ) @@ -156,7 +157,9 @@ class EmitConv2dInstance: ${stages}, ${math_operator}, ${iterator_algorithm}, - ${stride_support} + ${stride_support}, + ${align_a}, + ${align_b} >::Kernel; """ @@ -198,7 +201,9 @@ class EmitConv2dInstance: 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), 'stride_support': StrideSupportTag[operation.stride_support], 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \ - MathOperationTag[operation.tile_description.math_instruction.math_operation] + MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), } return SubstituteTemplate(self.template, values) @@ -341,4 +346,3 @@ void initialize_${configuration_name}(Manifest &manifest) { ################################################################################################### ################################################################################################### - diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 57d2a4ca..e5bc3464 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -151,14 +151,13 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t # Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low ########################################################################################################### # Convolution for 2D operations 
-def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \ +def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): element_a, element_b, element_c, element_epilogue = data_type # one exceptional case - alignment_c = min(8, alignment) # iterator algorithm (analytic and optimized) iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] @@ -166,66 +165,71 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme # by default, only generate the largest tile size if manifest.args.kernels == '': tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] operations = [] for tile in tile_descriptions: - A = TensorDescription(element_a, layout[0], alignment) - B = TensorDescription(element_b, layout[1], alignment) - C = TensorDescription(element_c, layout[2], alignment_c) - - swizzling_functor_ = swizzling_functor + for alignment in alignment_constraints: - # - # Conv2d Fprop - # - if ConvKind.Fprop in conv_kinds: + alignment_c = min(8, alignment) - # Strided support for Analytic and Optimized Fprop - for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - - manifest.append(new_operation) - operations.append(new_operation) - - # - # Conv2d Dgrad - # - if ConvKind.Dgrad in conv_kinds: - - # Unity stride for Analytic and Optimized Dgrad - for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, 
swizzling_functor_) - - manifest.append(new_operation) - operations.append(new_operation) - - # Strided support for Analytic Dgrad - # strided dgrad uses a special threadblock swizzle - # note that SwizzlingFunctor.StridedDgradHorizontal might be - # better for problem sizes with large activation channel count - swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1 - - new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) - - manifest.append(new_operation) - operations.append(new_operation) + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) - # - # Conv2d Wgrad - # - if ConvKind.Wgrad in conv_kinds: - - # Strided support for Analytic and Optimized Wgrad - for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + # + # Conv2d Dgrad + # + if ConvKind.Dgrad in conv_kinds: + + # Unity stride for Analytic and Optimized Dgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, 
StrideSupport.Unity, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Analytic Dgrad + # strided dgrad uses a special threadblock swizzle + # note that SwizzlingFunctor.StridedDgradHorizontal might be + # better for problem sizes with large activation channel count + swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1 + + new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) + manifest.append(new_operation) operations.append(new_operation) + + # + # Conv2d Wgrad + # + if ConvKind.Wgrad in conv_kinds: + + # Strided support for Analytic and Optimized Wgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) return operations @@ -322,7 +326,7 @@ def GenerateSM50_Simt(manifest, args): if math_inst.element_a == DataType.f32: conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # # @@ -369,7 +373,7 @@ def GenerateSM50_Simt_complex(manifest, args): data_type, alignment_constraints) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # # @@ -543,7 +547,7 @@ def GenerateSM70_TensorOp_884(manifest, args): data_type, 
alignment_constraints) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: @@ -558,7 +562,7 @@ def GenerateSM70_TensorOp_884(manifest, args): CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) # def GenerateSM70_PlanarComplexTensorOp_884(manifest, args): @@ -754,7 +758,7 @@ def GenerateSM75_TensorOp_1688(manifest, args): data_type, alignment_constraints) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: @@ -769,7 +773,7 @@ def GenerateSM75_TensorOp_1688(manifest, args): CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) # @@ -891,7 +895,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args): conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: @@ -909,7 +913,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args): data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: if op.tile_description.threadblock_shape[1] >= 128: @@ -972,7 +976,7 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args): conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: op.C.alignment = 8 @@ -1028,7 +1032,7 @@ def 
GenerateSM75_TensorOp_8832_TN(manifest, args): conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: @@ -1046,7 +1050,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args): data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: if op.tile_description.threadblock_shape[1] >= 128: @@ -1112,7 +1116,7 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args): conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: op.C.alignment = 16 @@ -1250,7 +1254,7 @@ def GenerateSM75_Simt_complex(manifest, args): ] conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # def GenerateSM75(manifest, args): @@ -1338,7 +1342,7 @@ def GenerateSM80_TensorOp_16816(manifest, args): data_type, alignment_constraints) conv_layout = 
(LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) @@ -1354,7 +1358,7 @@ def GenerateSM80_TensorOp_16816(manifest, args): CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8) # @@ -1572,10 +1576,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args): conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: if op.tile_description.threadblock_shape[1] >= 128: @@ -1689,7 +1693,7 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, args): conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, 
alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: op.C.alignment = 8 @@ -1758,10 +1762,10 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args): conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: if op.tile_description.threadblock_shape[1] >= 128: @@ -1878,7 +1882,7 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args): conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: op.C.alignment = 16 @@ -2005,9 +2009,9 @@ def GenerateSM80_TensorOp_1688(manifest, args): data_type_mixed, alignment_constraints) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 4) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) # # @@ -2076,7 +2080,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args): data_type, 
alignment_constraints) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # # @@ -2366,7 +2370,7 @@ def GenerateSM80_Simt_f32(manifest, args): data_type, alignment_constraints) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # @@ -2467,7 +2471,7 @@ def GenerateSM80_Simt_complex(manifest, args): CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) # ###################################################################################################