diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index c23114a0..a4163d97 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -169,7 +169,6 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementAccumulator, // Data type of accumulator ElementComputeEpilogue>; // Data type for alpha/beta in linear combination - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, @@ -592,7 +591,7 @@ Result profile_convolution(Options const &options) { std::cout << "Results written to '" << ss.str() << "'." << std::endl; } - + // // Performance measurement // diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 96a0ee40..26071666 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -61,17 +61,16 @@ struct ReLu { static const bool kIsHeavy=false; CUTLASS_HOST_DEVICE T operator()(T const & threshold, T value) const { - if (value < threshold) { - value = threshold; - } - return value; + maximum mx; + + return mx(value, threshold); } + CUTLASS_HOST_DEVICE T operator()(T value) const { - if (value < T(0)) { - value = T(0); - } - return value; + maximum mx; + + return mx(value, T(0)); } }; @@ -80,32 +79,16 @@ struct ReLu> { static const bool kIsHeavy=false; CUTLASS_HOST_DEVICE Array operator()(T const & threshold, Array const &frag) const { - Array result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - T value = frag[i]; - if (value < threshold) { - value = threshold; - } - result[i] = value; - } - return result; + maximum > mx; + + return mx(threshold, frag); } CUTLASS_HOST_DEVICE Array operator()(Array const &frag) const { - Array result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - T value = frag[i]; - if (value < T(0)) { - value = T(0); - } - result[i] = value; - } - return result; + maximum > mx; + return mx(frag, T(0)); } - }; // Sigmoid operator @@ -125,7 +108,7 @@ struct Sigmoid > { Sigmoid sigmoid_op; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(rhs.size()); ++i) { + for (int i = 0; i < N; ++i) { y[i] = sigmoid_op(rhs[i]); } @@ -193,7 +176,20 @@ struct HardSwish { minimum mn; maximum mx; T relu6 = mn(mx(x + T(3), T(0)), T(6)); - return x * (relu6 / T(6)); + return x * relu6 / T(6); + } +}; + +template <> +struct HardSwish { + using T = float; + + CUTLASS_HOST_DEVICE + T operator()(T const &x) const { + minimum mn; + maximum mx; + T relu6 = mn(mx(x + T(3), T(0)), T(6)); + return x * relu6 * 0.16666667f; } }; @@ -213,6 +209,21 @@ struct HardSwish > { } }; +template +struct HardSwish > { + using T = half_t; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + minimum > mn; + maximum > mx; + multiplies > mul; + plus > add; + + return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f)); + } +}; + // // GELU function definitions implemented as described by // Hendrycks, D., and Gimpel, K. in diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index 7bd02526..bb98b0ec 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -162,6 +162,8 @@ public: if (Scale == ScaleType::OnlyAlphaScaling) return false; + if (Scale == ScaleType::Nothing) return false; + return beta_ != ElementCompute(0); } @@ -197,8 +199,15 @@ public: minimum min_accumulator; maximum max_accumulator; - intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } /// Clamping constant value ElementCompute const kClampMax = @@ -235,7 +244,11 @@ public: minimum min_accumulator; maximum max_accumulator; - intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } /// Clamping constant value ElementCompute const kClampMax = @@ -367,6 +380,8 @@ public: if (Scale == ScaleType::OnlyAlphaScaling) return false; + if (Scale == ScaleType::Nothing) return false; + return beta_ != ElementCompute(0); } @@ -399,8 +414,15 @@ public: multiply_add mul_add_accumulator; // Float min-max - intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } // Convert floats back to INT FragmentAccumulator scaled_accumulator; @@ -430,7 +452,11 @@ public: multiplies mul_add_accumulator; // Float min-max - intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } // Convert floats back to INT FragmentAccumulator scaled_accumulator; @@ -551,6 +577,8 @@ class FastLinearCombinationClamp { if (Scale == ScaleType::OnlyAlphaScaling) return false; + if (Scale == ScaleType::Nothing) return false; + return beta_ != ElementCompute(0); } @@ -586,10 +614,17 @@ class FastLinearCombinationClamp { maximum max_accumulator; // Float min-max - intermediate = - mul_add_source(beta_, converted_source); // X = beta * C + uniform - intermediate = mul_add_accumulator(alpha_, converted_accumulator, - intermediate); // D = alpha * Accum + X + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = + mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, + intermediate); // D = alpha * Accum + X + } /// Clamping constant value ElementCompute const kClamp = @@ -624,7 +659,11 @@ class FastLinearCombinationClamp { maximum max_accumulator; // Float min-max - intermediate = mul_accumulator(alpha_, converted_accumulator); + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_accumulator(alpha_, converted_accumulator); + } /// Clamping constant value ElementCompute const kClamp = diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h index 2bf05b7b..5ad986a0 100644 --- a/include/cutlass/epilogue/thread/linear_combination_gelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -51,11 +51,11 @@ template < ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling FloatRoundStyle Round = FloatRoundStyle::round_to_nearest > using LinearCombinationGELU = LinearCombinationGeneric; - + ElementCompute_, Scale, Round, true>; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_generic.h b/include/cutlass/epilogue/thread/linear_combination_generic.h index 17f961e8..1712a283 100644 --- a/include/cutlass/epilogue/thread/linear_combination_generic.h +++ b/include/cutlass/epilogue/thread/linear_combination_generic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -54,6 +54,7 @@ template < ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, bool IsHeavy = false > @@ -66,10 +67,11 @@ public: static bool const kIsHeavy = IsHeavy; static int const kCount = Count; + static const ScaleType::Kind kScale = Scale; using FragmentOutput = Array; using FragmentAccumulator = Array; - using ComputeFragment = Array; + using FragmentCompute = Array; static FloatRoundStyle const kRound = Round; @@ -131,6 +133,12 @@ public: /// Returns true if source is needed CUTLASS_HOST_DEVICE bool is_source_needed() const { + if (Scale == ScaleType::NoBetaScaling) return true; + + if (Scale == ScaleType::OnlyAlphaScaling) return false; + + if (Scale == ScaleType::Nothing) return false; + return beta_ != ElementCompute(0); } @@ -152,19 +160,26 @@ public: NumericArrayConverter source_converter; NumericArrayConverter accumulator_converter; - ComputeFragment converted_source = source_converter(source); - ComputeFragment converted_accumulator = accumulator_converter(accumulator); + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); // Perform binary operations - ComputeFragment intermediate; + FragmentCompute intermediate; - multiplies mul_add_source; - multiply_add mul_add_accumulator; - ActivationFunctor activation; + multiplies mul_add_source; + multiply_add mul_add_accumulator; + ActivationFunctor activation; - intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } intermediate = activation(intermediate); @@ -182,16 +197,20 @@ public: // Convert source to interal compute numeric type NumericArrayConverter accumulator_converter; - ComputeFragment converted_accumulator = accumulator_converter(accumulator); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); // Perform binary operations - ComputeFragment intermediate; + FragmentCompute intermediate; - multiplies mul_add_accumulator; - ActivationFunctor activation; + multiplies mul_add_accumulator; + ActivationFunctor activation; - intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } intermediate = activation(intermediate); diff --git a/include/cutlass/epilogue/thread/linear_combination_hardswish.h b/include/cutlass/epilogue/thread/linear_combination_hardswish.h index e6c37d50..d6f57270 100644 --- a/include/cutlass/epilogue/thread/linear_combination_hardswish.h +++ b/include/cutlass/epilogue/thread/linear_combination_hardswish.h @@ -1,5 +1,5 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. +/*************************************************************************************************** +* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -51,10 +51,11 @@ template < ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling FloatRoundStyle Round = FloatRoundStyle::round_to_nearest > using LinearCombinationHardSwish = LinearCombinationGeneric; + ElementCompute_, Scale, Round>; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 6479c31e..7bd79a3b 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -156,6 +156,8 @@ public: if (Scale == ScaleType::OnlyAlphaScaling) return false; + if (Scale == ScaleType::Nothing) return false; + return beta_ != ElementCompute(0); } @@ -193,12 +195,15 @@ public: multiply_add mul_add_accumulator; ReLu relu; - if (Scale == ScaleType::NoBetaScaling) + if (Scale == ScaleType::NoBetaScaling) { intermediate = converted_source; - else + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform - - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } // Compute threshold optionally intermediate = relu(threshold_, intermediate); @@ -225,7 +230,11 @@ public: multiplies mul_accumulator; ReLu relu; - intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } // Compute threshold optionally intermediate = relu(threshold_, intermediate); @@ -269,8 +278,6 @@ public: return destination_converter(intermediate); } - - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -376,6 +383,8 @@ public: if (Scale == ScaleType::OnlyAlphaScaling) return false; + if (Scale == ScaleType::Nothing) return false; + return beta_ != ElementCompute(0); } @@ -413,12 +422,15 @@ public: multiply_add mul_add_accumulator; ReLu relu; - if (Scale == ScaleType::NoBetaScaling) + if (Scale == ScaleType::NoBetaScaling) { intermediate = converted_source; - else + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform - - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } // Compute threshold optionally intermediate = relu(threshold_, intermediate); @@ -459,7 +471,11 @@ public: multiplies mul_accumulator; ReLu relu; - intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } // Compute threshold optionally intermediate = relu(threshold_, intermediate); @@ -512,22 +528,13 @@ public: // Compute threshold optionally intermediate = relu(threshold_, intermediate); - if (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value) { + if (platform::numeric_limits::is_integer) { // Convert floats back to INT FragmentAccumulator scaled_accumulator; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - scaled_accumulator[i] = __float2int_rn(intermediate[i]); - } + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); // Convert to destination numeric type NumericArrayConverter @@ -540,7 +547,6 @@ public: return destination_converter(intermediate); } } - }; #endif // Conditional guards to enable partial specialization for packed integers diff --git a/include/cutlass/epilogue/thread/linear_combination_relu0.h b/include/cutlass/epilogue/thread/linear_combination_relu0.h new file mode 100644 index 00000000..0cb4f956 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_relu0.h @@ -0,0 +1,535 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a relu operation used by epilogues. + This one only supports relu0 and tries to folding relu into other instructions. Thus, + serial splitk is not supported by this one. For example, relu can be folded into + hfma2/hmul2 for sm80+ +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Single source of truth for whether to unroll for `LinearCombinationClamp()` +constexpr bool LinearCombinationRelu0IsHeavy() { + return false; +} + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationRelu0 { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + static const ScaleType::Kind kScale = Scale; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentScaleBias = Array; + + static FloatRoundStyle const kRound = Round; + + static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy(); + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta = ElementCompute(0) + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr = nullptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationRelu0(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + if (Scale == ScaleType::NoBetaScaling) return true; + + if (Scale == ScaleType::OnlyAlphaScaling) return false; + + if (Scale == ScaleType::Nothing) return false; + + return beta_ != ElementCompute(0); + } + + /// This is used for serial reduction which is not supported by Relu0 + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(k_partition == 0); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add_relu0 mul_add_relu0_accumulator; + ReLu relu; + + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + + // Compute threshold optionally + intermediate = relu(intermediate); + } else { + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + ReLu relu; + + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } + + // Compute threshold optionally + intermediate = relu(intermediate); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } + + /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias + /// Scale and Bias are from input Fragment + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentScaleBias const &scale, + FragmentScaleBias const &bias) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform per-channel scale and bias + FragmentCompute intermediate; + + multiply_add mul_add_accumulator; + + if(Scale == ScaleType::OnlyAlphaPerChannelScaling) + intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias + else + intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias + + ReLu relu; + + // Compute threshold optionally + intermediate = relu(intermediate); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conditional guards to enable partial specialization for packed integers +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +/// Special handling for int types + +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + ScaleType::Kind Scale, ///< Control Alpha and Beta scaling + FloatRoundStyle Round +> +class LinearCombinationRelu0 { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = int; + using ElementCompute = float; + + static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy(); + + static int const kCount = Count; + static const ScaleType::Kind kScale = Scale; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentScaleBias = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta = ElementCompute(0) + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr = nullptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationRelu0(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + if (Scale == ScaleType::NoBetaScaling) return true; + + if (Scale == ScaleType::OnlyAlphaScaling) return false; + + if (Scale == ScaleType::Nothing) return false; + + return beta_ != ElementCompute(0); + } + + /// This is used for serial reduction which is not supported by Relu0 + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(k_partition == 0); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + ReLu relu; + + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } + + // Compute threshold optionally + intermediate = relu(intermediate); + + if (platform::numeric_limits::is_integer) { + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(scaled_accumulator); + } else { + NumericArrayConverter + destination_converter; + return destination_converter(intermediate); + } + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + ReLu relu; + + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } + + // Compute threshold optionally + intermediate = relu(intermediate); + + if (platform::numeric_limits::is_integer) { + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(scaled_accumulator); + } else { + NumericArrayConverter + destination_converter; + return destination_converter(intermediate); + } + } + + /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias + /// Scale and Bias are from input Fragment + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentScaleBias const &scale, + FragmentScaleBias const &bias) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform per-channel scale and bias + FragmentCompute intermediate; + + multiply_add mul_add_accumulator; + + if(Scale == ScaleType::OnlyAlphaPerChannelScaling) + intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias + else + intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias + + ReLu relu; + + // Compute threshold optionally + intermediate = relu(intermediate); + + if (platform::numeric_limits::is_integer) { + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(scaled_accumulator); + } else { + NumericArrayConverter + destination_converter; + return destination_converter(intermediate); + } + } +}; + +#endif // Conditional guards to enable partial specialization for packed integers + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_sigmoid.h b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h index e5ef55c8..eb5516ed 100644 --- a/include/cutlass/epilogue/thread/linear_combination_sigmoid.h +++ b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -51,10 +51,12 @@ template < ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling FloatRoundStyle Round = FloatRoundStyle::round_to_nearest > using LinearCombinationSigmoid = LinearCombinationGeneric; + ElementCompute_, Scale, Round, true>; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread diff --git a/include/cutlass/epilogue/thread/linear_combination_silu.h b/include/cutlass/epilogue/thread/linear_combination_silu.h index e9a3e2c9..79ef792e 100644 --- a/include/cutlass/epilogue/thread/linear_combination_silu.h +++ b/include/cutlass/epilogue/thread/linear_combination_silu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -51,10 +51,11 @@ template < ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling FloatRoundStyle Round = FloatRoundStyle::round_to_nearest > using LinearCombinationSilu = LinearCombinationGeneric; + ElementCompute_, Scale, Round, true>; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index 8b1d803f..177c3367 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -36,13 +36,17 @@ #include "cutlass/numeric_types.h" #include "cutlass/array.h" +#include "cutlass/platform/platform.h" + #include "cutlass/gemm/gemm.h" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/linear_combination_clamp.h" #include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" #include "cutlass/epilogue/thread/linear_combination_gelu.h" #include "cutlass/epilogue/thread/linear_combination_sigmoid.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" #include "cutlass/epilogue/thread/linear_combination_planar_complex.h" #include "cutlass/epilogue/thread/conversion_op.h" diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 0f3d4eb5..63f16c3e 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -224,6 +224,40 @@ struct less { } }; +template +struct maximum { + + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { + return (lhs < rhs ? rhs : lhs); + } +}; + +template <> +struct maximum { + CUTLASS_HOST_DEVICE + float operator()(float const &lhs, float const &rhs) const { + return fmaxf(lhs, rhs); + } +}; + +template +struct minimum { + + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { + return (rhs < lhs ? rhs : lhs); + } +}; + +template <> +struct minimum { + CUTLASS_HOST_DEVICE + float operator()(float const &lhs, float const &rhs) const { + return fminf(lhs, rhs); + } +}; + /// Fused multiply-add template struct multiply_add { @@ -233,6 +267,16 @@ struct multiply_add { } }; +/// Fused multiply-add +template +struct multiply_add_relu0 { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + maximum mx; + return mx(C(a) * C(b) + c, C(0)); + } +}; + /// Fused multiply-add template struct and_add { @@ -366,7 +410,6 @@ struct bit_or> { } }; - // Partial specializations for Arrays template struct bit_not> { @@ -560,139 +603,6 @@ struct plus> { return result; } }; - - -template -struct maximum { - - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { - return (lhs < rhs ? rhs : lhs); - } -}; - -template <> -struct maximum { - CUTLASS_HOST_DEVICE - float operator()(float const &lhs, float const &rhs) const { - return fmaxf(lhs, rhs); - } -}; - -template -struct maximum> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct minimum { - - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { - return (rhs < lhs ? rhs : lhs); - } -}; - -template <> -struct minimum { - CUTLASS_HOST_DEVICE - float operator()(float const &lhs, float const &rhs) const { - return fminf(lhs, rhs); - } -}; - -template -struct minimum> { - - CUTLASS_HOST_DEVICE - static T scalar_op(T const &lhs, T const &rhs) { - return (rhs < lhs ? rhs : lhs); - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - template struct minus> { @@ -831,6 +741,102 @@ struct divides> { } }; +template +struct maximum> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum> { + + CUTLASS_HOST_DEVICE + static T scalar_op(T const &lhs, T const &rhs) { + return (rhs < lhs ? rhs : lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; template struct negate> { @@ -897,6 +903,56 @@ struct multiply_add, Array, Array> { } }; +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); + } + + return result; + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for Array targeting SIMD instructions in device code. @@ -1536,6 +1592,413 @@ struct multiply_add, Array, Array> { } }; +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); + } + + if (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma_relu( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a, b[i], c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b, c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c)); + } + #endif + + return result; + } +}; + +template +struct minimum> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmin( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmin( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs[i] < lhs ? rhs[i] : lhs); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmin( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs < lhs[i] ? rhs : lhs[i]); + } + #endif + + return result; + } +}; + +template +struct maximum> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmax( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmax( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs < rhs[i] ? rhs[i] : lhs); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmax( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs[i] < rhs ? rhs : lhs[i]); + } + #endif + + return result; + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Fused multiply-add