/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief This file contains definitions and utility functions for describing convolution problem shapes. */ #pragma once #include "cutlass/cutlass.h" #include "cutlass/tensor_coord.h" #include "cutlass/conv/convolution.h" #include "cute/container/array.hpp" #if ! defined(__CUDACC_RTC__) #include #endif //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::conv { //////////////////////////////////////////////////////////////////////////////////////////////////// // Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion. // All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK) // Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types. template < conv::Operator ConvOp_, int NumSpatialDimensions > struct ConvProblemShape { // // Alias types for members // static constexpr int RankS = NumSpatialDimensions; static constexpr int RankT = NumSpatialDimensions + 2; static constexpr conv::Operator ConvOp = ConvOp_; using SpatialExtent = cute::array; using TensorExtent = cute::array; using TensorStride = cute::array; using ShapePadding = SpatialExtent; using TraversalStride = SpatialExtent; using ShapeDilation = SpatialExtent; using Corner = SpatialExtent; // // Members // cutlass::conv::Mode mode{}; TensorExtent shape_A{}; TensorStride stride_A{}; TensorExtent shape_B{}; TensorStride stride_B{}; TensorExtent shape_C{}; TensorStride stride_C{}; // asymmetric padding, both upper and lower padding must be >= 0 ShapePadding lower_padding{}; ShapePadding upper_padding{}; TraversalStride traversal_stride{}; ShapeDilation dilation{}; int groups = 1; // // Methods // ConvProblemShape() = default; // Constructor accepts user facing arguments and computes to stores the corners as its internal state ConvProblemShape( conv::Mode mode, // convolution/cross-correlation TensorExtent shape_act, // [n,d,h,w,c] TensorStride stride_act, // [n,d,h,w,c] TensorExtent shape_flt, // [k,t,r,s,c] TensorStride stride_flt, // [k,t,r,s,c] ShapePadding lower_padding, // [pad_d, pad_h, pad_w] ShapePadding upper_padding, // [pad_d, pad_h, pad_w] TraversalStride tstride, // [stride_d, stride_h, stride_w] ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] int groups) : mode(mode) , lower_padding(lower_padding) , upper_padding(upper_padding) , traversal_stride(tstride) , dilation(dilation) , groups(groups) { auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); } // Allow user input of xformed activation stride to support non-packed strides. ConvProblemShape( conv::Mode mode, // convolution/cross-correlation TensorExtent shape_act, // [n,d,h,w,c] TensorStride stride_act, // [n,d,h,w,c] TensorExtent shape_flt, // [k,t,r,s,c] TensorStride stride_flt, // [k,t,r,s,c] TensorStride stride_xformed_act, // [n,z,p,q,k] ShapePadding lower_padding, // [pad_d, pad_h, pad_w] ShapePadding upper_padding, // [pad_d, pad_h, pad_w] TraversalStride tstride, // [stride_d, stride_h, stride_w] ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] int groups) : mode(mode) , lower_padding(lower_padding) , upper_padding(upper_padding) , traversal_stride(tstride) , dilation(dilation) , groups(groups) { CUTLASS_ASSERT(stride_act[RankT - 1] == 1); CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); auto stride_act_packed = packed_stride_right_major(shape_act); auto stride_flt_packed = packed_stride_right_major(shape_flt); auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); CUTLASS_PRAGMA_UNROLL for(int i = 0; i < RankT - 1; ++i) { CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); } set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); } // Constructor accepts user facing arguments and presume packed tensor strides in canonical (CWHDN) order. ConvProblemShape( conv::Mode mode, TensorExtent shape_act, TensorExtent shape_flt, ShapePadding lower_padding, ShapePadding upper_padding, TraversalStride tstride, ShapeDilation dilation, int groups) : ConvProblemShape( mode, shape_act, packed_stride_right_major(shape_act), shape_flt, packed_stride_right_major(shape_flt), lower_padding, upper_padding, tstride, dilation, groups) { } #if ! defined(__CUDACC_RTC__) // Constructor accepts user facing arguments and computes to stores the corners as its internal state ConvProblemShape( conv::Mode mode, std::initializer_list shape_act_, std::initializer_list stride_act_, std::initializer_list shape_flt_, std::initializer_list stride_flt_, std::initializer_list lower_padding_, std::initializer_list upper_padding_, std::initializer_list traversal_stride_, std::initializer_list dilation_, int groups) : mode(mode) , groups(groups) { TensorExtent shape_act{}; TensorStride stride_act{}; TensorExtent shape_flt{}; TensorStride stride_flt{}; assert(shape_act_.size() == shape_act.size()); assert(stride_act_.size() == stride_act.size()); assert(shape_flt_.size() == shape_flt.size()); assert(stride_flt_.size() == stride_flt.size()); assert(lower_padding_.size() == lower_padding.size()); assert(upper_padding_.size() == upper_padding.size()); assert(traversal_stride_.size() == traversal_stride.size()); assert(dilation_.size() == dilation.size()); std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); } // Allow user input of xformed activation stride to support non-packed strides. ConvProblemShape( conv::Mode mode, std::initializer_list shape_act_, std::initializer_list stride_act_, std::initializer_list shape_flt_, std::initializer_list stride_flt_, std::initializer_list stride_xformed_act_, std::initializer_list lower_padding_, std::initializer_list upper_padding_, std::initializer_list traversal_stride_, std::initializer_list dilation_, int groups) : mode(mode) , groups(groups) { TensorExtent shape_act{}; TensorStride stride_act{}; TensorExtent shape_flt{}; TensorStride stride_flt{}; TensorStride stride_xformed_act{}; std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin()); std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); CUTLASS_ASSERT(stride_act[RankT - 1] == 1); CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); auto stride_act_packed = packed_stride_right_major(shape_act); auto stride_flt_packed = packed_stride_right_major(shape_flt); auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); CUTLASS_PRAGMA_UNROLL for(int i = 0; i < RankT - 1; ++i) { CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); } set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); } // Constructor accepts user facing arguments and computes to stores the corners as its internal state ConvProblemShape( conv::Mode mode, std::initializer_list shape_act_, std::initializer_list shape_flt_, std::initializer_list lower_padding_, std::initializer_list upper_padding_, std::initializer_list traversal_stride_, std::initializer_list dilation_, int groups) : mode(mode) , groups(groups) { TensorExtent shape_act{}; TensorStride stride_act{}; TensorExtent shape_flt{}; TensorStride stride_flt{}; assert(shape_act_.size() == shape_act.size()); assert(shape_flt_.size() == shape_flt.size()); assert(lower_padding_.size() == lower_padding.size()); assert(upper_padding_.size() == upper_padding.size()); assert(traversal_stride_.size() == traversal_stride.size()); assert(dilation_.size() == dilation.size()); std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); stride_act = packed_stride_right_major(shape_act); stride_flt = packed_stride_right_major(shape_flt); auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); } #endif // not defined(__CUDACC_RTC__) // Set shape and stride of tensor A/B/C according to following table: // | | Fprop | Dgrad | Wgrad | // | ------ | ------ | ------ | ------| // | ShapeA | NDHWC | NZPQK | NZPQK | // | ShapeB | KTRSC | KTRSC | NDHWC | // | ShapeC | NZPQK | NDHWC | KTRSC | // CUTLASS_HOST_DEVICE constexpr void set_shape_stride_ABC( TensorExtent shape_act, TensorStride stride_act, TensorExtent shape_flt, TensorStride stride_flt, TensorExtent shape_xformed_act, TensorStride stride_xformed_act) { if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { shape_A = shape_act; stride_A = stride_act; shape_B = shape_flt; stride_B = stride_flt; shape_C = shape_xformed_act; stride_C = stride_xformed_act; } else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { shape_A = shape_xformed_act; stride_A = stride_xformed_act; shape_B = shape_flt; stride_B = stride_flt; shape_C = shape_act; stride_C = stride_act; } else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { shape_A = shape_xformed_act; stride_A = stride_xformed_act; shape_B = shape_act; stride_B = stride_act; shape_C = shape_flt; stride_C = stride_flt; } } // Get problem shape MNK according to following table: // | | Fprop | Dgrad | Wgrad | // | ---- | --------- | -------- | -------- | // | Shape_M | (Q,P,Z,N) | (W,H,D,N) | (K) | // | Shape_N | (K) | (C) | (C,S,R,T) | // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | CUTLASS_HOST_DEVICE constexpr auto get_transformed_problem_shape_MNK() const { using cute::insert; using cute::make_shape; using cute::reverse; using cute::take; if constexpr (ConvOp == conv::Operator::kWgrad) { auto M_xformed = shape_C[0]; auto N_xformed = reverse(take<1, RankT>(shape_C)); auto K_xformed = reverse(take<0, RankT - 1>(shape_A)); return make_shape(M_xformed, N_xformed, K_xformed); } else if constexpr (ConvOp == conv::Operator::kFprop){ auto M_xformed = reverse(take<0, RankT - 1>(shape_C)); auto N_xformed = shape_C[RankT - 1]; auto K_xformed = reverse(take<1, RankT>(shape_B)); return make_shape(M_xformed, N_xformed, K_xformed); } else if constexpr (ConvOp == conv::Operator::kDgrad) { auto M_xformed = reverse(take<0,RankT - 1>(shape_C)); auto N_xformed = shape_C[RankT - 1]; // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] auto K_xformed = insert<0>( (reverse(take<1,RankT - 1>(shape_B))), shape_B[0]); return make_shape(M_xformed, N_xformed, K_xformed); } } // Get A extents. // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N)) // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) CUTLASS_HOST_DEVICE constexpr auto get_shape_A() const { using cute::make_shape; using cute::take; if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { return make_shape( cute::reverse(take<0, RankT - 1>(shape_A)), shape_A[RankT - 1]); } // For wgrad kernel, we need to linearize NZPQ for tensor A else if constexpr (ConvOp == conv::Operator::kWgrad) { return make_shape( shape_A[RankT - 1], cute::product(take<0, RankT - 1>(shape_A))); } } // Get B extents. // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N)) // dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T)) CUTLASS_HOST_DEVICE constexpr auto get_shape_B() const { using cute::make_shape; using cute::reverse; using cute::take; if constexpr (ConvOp == conv::Operator::kFprop) { return make_shape( shape_B[0], reverse(take<1, RankT>(shape_B))); } else if constexpr (ConvOp == conv::Operator::kWgrad) { return make_shape( shape_B[RankT - 1], reverse(take<0, RankT - 1>(shape_B))); } else if constexpr (ConvOp == conv::Operator::kDgrad) { // shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)] return make_shape( shape_B[RankT - 1], cute::insert<0>( reverse(take<1, RankT - 1>(shape_B)), shape_B[0])); } } // Static method that returns the canonical strides of tensors (layouts are right major and compact) CUTLASS_HOST_DEVICE static constexpr TensorStride packed_stride_right_major(TensorExtent const& extents) { TensorStride strides{}; strides[RankT-1] = 1; cute::for_each(cute::make_rseq{}, [&](auto i) { strides[i] = extents[i+1] * strides[i+1]; }); return strides; } // Static method that returns the packed logical size of any TensorExtent CUTLASS_HOST_DEVICE static constexpr size_t size(TensorExtent const& extents) { size_t size = 1; cute::for_each(cute::make_seq{}, [&](auto i) { size *= extents[i]; }); return size; } CUTLASS_HOST_DEVICE constexpr size_t size_A() const { return shape_A[0] * stride_A[0]; } CUTLASS_HOST_DEVICE constexpr size_t size_B() const { return shape_B[0] * stride_B[0]; } CUTLASS_HOST_DEVICE constexpr size_t size_C() const { return shape_C[0] * stride_C[0]; } // Equality operator CUTLASS_HOST_DEVICE bool operator==(ConvProblemShape const& rhs) const { using cute::for_each; using cute::make_seq; bool is_equal = true; // Compare all tensor extents for_each(make_seq{}, [&](auto i) { is_equal = is_equal && (shape_A[i] == rhs.shape_A[i]) && (shape_B[i] == rhs.shape_B[i]); }); // Compare all spatial extents for_each(make_seq{}, [&](auto i) { is_equal = is_equal && (lower_padding[i] == rhs.lower_padding[i]) && (upper_padding[i] == rhs.upper_padding[i]) && (traversal_stride[i] == rhs.traversal_stride[i]) && (dilation[i] == rhs.dilation[i]); }); return is_equal; } /// Inequality operator CUTLASS_HOST_DEVICE bool operator!=(ConvProblemShape const &rhs) const { return !(*this == rhs); } private: CUTLASS_HOST_DEVICE constexpr auto calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) { TensorExtent shape_xformed_act{}; // calculate n,z,p,q,k. // a helper lambda to compute a single spatial extent of the nzpqk tensor auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; }; shape_xformed_act[0] = shape_act[0]; // Activation N extent cute::for_each(cute::make_seq{}, [&](auto i) { shape_xformed_act[i+1] = nzpqk_extent( shape_act[i+1], shape_flt[i+1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]); }); shape_xformed_act[RankT-1] = shape_flt[0]; // Filter K extent TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act); return cute::make_tuple(shape_xformed_act, stride_xformed_act); } }; template< conv::Operator ConvOp, int SpatialDim > void print(ConvProblemShape const& problem) { printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n", SpatialDim, int(ConvOp)); printf("\tTensorA: "); cute::print(problem.shape_A); printf(":"); cute::print(problem.stride_A); printf("\n"); printf("\tTensorB: "); cute::print(problem.shape_B); printf(":"); cute::print(problem.stride_B); printf("\n"); printf("\tTensorC: "); cute::print(problem.shape_C); printf(":"); cute::print(problem.stride_C); printf("\n"); printf("\tLower padding: "); print(problem.lower_padding); printf("\n"); printf("\tUpper padding: "); print(problem.upper_padding); printf("\n"); printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n"); printf("\tDilation: "); print(problem.dilation); printf("\n"); } //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv ////////////////////////////////////////////////////////////////////////////////////////////////////