[hardswish] correct implementation (#403)

* [hardswish] correct implementation

* seems working

* hardswish fp32/fp16x2 optimization

* [relu] half2 support

* add relu0; add multiply_add_relu0;

* cleanup

Co-authored-by: Bing Xu <bingxu@fb.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Author: Bing Xu
Date: 2022-02-09 11:28:53 -08:00
Committed by: GitHub
Parent: 8a951b2940
Commit: d0d941efc7
12 changed files with 1315 additions and 235 deletions
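The core of the fix is the hardswish formula itself, hardswish(x) = x * clamp(x + 3, 0, 6) / 6, with the float and half_t specializations below replacing the division by a multiplication with 1/6. A minimal host-side reference sketch (not part of this commit; names are illustrative) for spot-checking the functors:

#include <algorithm>
#include <cstdio>

// Reference hardswish: x * relu6(x + 3) / 6, written with the same 1/6 constant
// that the fp32 specialization below uses.
float hardswish_ref(float x) {
  float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
  return x * relu6 * 0.16666667f;
}

int main() {
  for (float x : {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f}) {
    std::printf("hardswish(%g) = %g\n", x, hardswish_ref(x));
  }
  return 0;
}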

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -169,7 +169,6 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
@ -592,7 +591,7 @@ Result profile_convolution(Options const &options) {
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -61,17 +61,16 @@ struct ReLu {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T const & threshold, T value) const {
if (value < threshold) {
value = threshold;
}
return value;
maximum<T> mx;
return mx(value, threshold);
}
CUTLASS_HOST_DEVICE
T operator()(T value) const {
if (value < T(0)) {
value = T(0);
}
return value;
maximum<T> mx;
return mx(value, T(0));
}
};
@ -80,32 +79,16 @@ struct ReLu<Array<T, N>> {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
Array<T, N> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
T value = frag[i];
if (value < threshold) {
value = threshold;
}
result[i] = value;
}
return result;
maximum<Array<T, N> > mx;
return mx(threshold, frag);
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
Array<T, N> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
T value = frag[i];
if (value < T(0)) {
value = T(0);
}
result[i] = value;
}
return result;
maximum<Array<T, N> > mx;
return mx(frag, T(0));
}
};
// Sigmoid operator
@ -125,7 +108,7 @@ struct Sigmoid<Array<T, N> > {
Sigmoid<T> sigmoid_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < int(rhs.size()); ++i) {
for (int i = 0; i < N; ++i) {
y[i] = sigmoid_op(rhs[i]);
}
@ -193,7 +176,20 @@ struct HardSwish {
minimum<T> mn;
maximum<T> mx;
T relu6 = mn(mx(x + T(3), T(0)), T(6));
return x * (relu6 / T(6));
return x * relu6 / T(6);
}
};
template <>
struct HardSwish<float> {
using T = float;
CUTLASS_HOST_DEVICE
T operator()(T const &x) const {
minimum<T> mn;
maximum<T> mx;
T relu6 = mn(mx(x + T(3), T(0)), T(6));
return x * relu6 * 0.16666667f;
}
};
@ -213,6 +209,21 @@ struct HardSwish<Array<T, N> > {
}
};
template <int N>
struct HardSwish<Array<half_t, N> > {
using T = half_t;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
minimum<Array<T, N> > mn;
maximum<Array<T, N> > mx;
multiplies<Array<T, N> > mul;
plus<Array<T, N> > add;
return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
}
};
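A minimal usage sketch of the new half_t specialization (fragment width 8 and the input value are assumptions, not from the commit); the Array<half_t, N> minimum/maximum operators it relies on lower to packed __hmin2/__hmax2 on SM80+, per the functional.h changes later in this diff:

cutlass::epilogue::thread::HardSwish<cutlass::Array<cutlass::half_t, 8>> hardswish_op;
cutlass::Array<cutlass::half_t, 8> x;
x.fill(cutlass::half_t(1.5f));                            // hypothetical input fragment
cutlass::Array<cutlass::half_t, 8> y = hardswish_op(x);   // y[i] = x[i] * clamp(x[i] + 3, 0, 6) / 6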
//
// GELU function definitions implemented as described by
// Hendrycks, D., and Gimpel, K. in

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -162,6 +162,8 @@ public:
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
@ -197,8 +199,15 @@ public:
minimum<ComputeFragment> min_accumulator;
maximum<ComputeFragment> max_accumulator;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
/// Clamping constant value
ElementCompute const kClampMax =
@ -235,7 +244,11 @@ public:
minimum<ComputeFragment> min_accumulator;
maximum<ComputeFragment> max_accumulator;
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
/// Clamping constant value
ElementCompute const kClampMax =
@ -367,6 +380,8 @@ public:
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
@ -399,8 +414,15 @@ public:
multiply_add<ComputeFragment> mul_add_accumulator;
// Float min-max
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
@ -430,7 +452,11 @@ public:
multiplies<ComputeFragment> mul_add_accumulator;
// Float min-max
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
@ -551,6 +577,8 @@ class FastLinearCombinationClamp {
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
@ -586,10 +614,17 @@ class FastLinearCombinationClamp {
maximum<ComputeFragment> max_accumulator;
// Float min-max
intermediate =
mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator,
intermediate); // D = alpha * Accum + X
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate =
mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator,
intermediate); // D = alpha * Accum + X
}
/// Clamping constant value
ElementCompute const kClamp =
@ -624,7 +659,11 @@ class FastLinearCombinationClamp {
maximum<ComputeFragment> max_accumulator;
// Float min-max
intermediate = mul_accumulator(alpha_, converted_accumulator);
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_accumulator(alpha_, converted_accumulator);
}
/// Clamping constant value
ElementCompute const kClamp =

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -51,11 +51,11 @@ template <
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round, true>;
ElementCompute_, Scale, Round, true>;
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -54,6 +54,7 @@ template <
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
bool IsHeavy = false
>
@ -66,10 +67,11 @@ public:
static bool const kIsHeavy = IsHeavy;
static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
using FragmentCompute = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
@ -131,6 +133,12 @@ public:
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
if (Scale == ScaleType::NoBetaScaling) return true;
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
@ -152,19 +160,26 @@ public:
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
FragmentCompute converted_source = source_converter(source);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
FragmentCompute intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
ActivationFunctor<ComputeFragment> activation;
multiplies<FragmentCompute> mul_add_source;
multiply_add<FragmentCompute> mul_add_accumulator;
ActivationFunctor<FragmentCompute> activation;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
intermediate = activation(intermediate);
@ -182,16 +197,20 @@ public:
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
FragmentCompute intermediate;
multiplies<ComputeFragment> mul_add_accumulator;
ActivationFunctor<ComputeFragment> activation;
multiplies<FragmentCompute> mul_add_accumulator;
ActivationFunctor<FragmentCompute> activation;
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
intermediate = activation(intermediate);

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -51,10 +51,11 @@ template <
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationHardSwish = LinearCombinationGeneric<HardSwish, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round>;
ElementCompute_, Scale, Round>;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
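With the ScaleType parameter now threaded through LinearCombinationGeneric, the epilogue can skip the beta * C term entirely; a hypothetical alias (element types and vector width are assumptions):

using EpilogueOp = cutlass::epilogue::thread::LinearCombinationHardSwish<
    cutlass::half_t,                                       // ElementOutput_
    8,                                                     // Count: elements per operation
    cutlass::half_t,                                       // ElementAccumulator_
    cutlass::half_t,                                       // ElementCompute_
    cutlass::epilogue::thread::ScaleType::NoBetaScaling>;  // skip beta scaling of the source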

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -156,6 +156,8 @@ public:
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
@ -193,12 +195,15 @@ public:
multiply_add<FragmentCompute> mul_add_accumulator;
ReLu<FragmentCompute> relu;
if (Scale == ScaleType::NoBetaScaling)
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
else
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
// Compute threshold optionally
intermediate = relu(threshold_, intermediate);
@ -225,7 +230,11 @@ public:
multiplies<FragmentCompute> mul_accumulator;
ReLu<FragmentCompute> relu;
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Compute threshold optionally
intermediate = relu(threshold_, intermediate);
@ -269,8 +278,6 @@ public:
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -376,6 +383,8 @@ public:
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
@ -413,12 +422,15 @@ public:
multiply_add<FragmentCompute> mul_add_accumulator;
ReLu<FragmentCompute> relu;
if (Scale == ScaleType::NoBetaScaling)
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
else
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
// Compute threshold optionally
intermediate = relu(threshold_, intermediate);
@ -459,7 +471,11 @@ public:
multiplies<FragmentCompute> mul_accumulator;
ReLu<FragmentCompute> relu;
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Compute threshold optionally
intermediate = relu(threshold_, intermediate);
@ -512,22 +528,13 @@ public:
// Compute threshold optionally
intermediate = relu(threshold_, intermediate);
if (platform::is_same<ElementOutput, int32_t>::value ||
platform::is_same<ElementOutput, uint32_t>::value ||
platform::is_same<ElementOutput, int16_t>::value ||
platform::is_same<ElementOutput, uint16_t>::value ||
platform::is_same<ElementOutput, int8_t>::value ||
platform::is_same<ElementOutput, uint8_t>::value ||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value) {
if (platform::numeric_limits<ElementOutput>::is_integer) {
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
scaled_accumulator[i] = __float2int_rn(intermediate[i]);
}
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
scaled_accumulator = compute_converter(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round>
@ -540,7 +547,6 @@ public:
return destination_converter(intermediate);
}
}
};
#endif // Conditional guards to enable partial specialization for packed integers

View File

@ -0,0 +1,535 @@
/***************************************************************************************************
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with a relu operation used by epilogues.
This one only supports relu0 and tries to fold relu into other instructions. Thus,
serial split-k is not supported by this one. For example, relu can be folded into
hfma2/hmul2 for sm80+.
*/
#pragma once
#include <cutlass/half.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// Single source of truth for whether to unroll for `LinearCombinationRelu0()`
constexpr bool LinearCombinationRelu0IsHeavy() {
return false;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationRelu0 {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using FragmentCompute = Array<ElementCompute, kCount>;
using FragmentScaleBias = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta = ElementCompute(0)
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr = nullptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationRelu0(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
if (Scale == ScaleType::NoBetaScaling) return true;
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
/// This is used for serial reduction which is not supported by Relu0
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
assert(k_partition == 0);
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
FragmentCompute converted_source = source_converter(source);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_add_source;
multiply_add_relu0<FragmentCompute> mul_add_relu0_accumulator;
ReLu<FragmentCompute> relu;
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
// Compute threshold optionally
intermediate = relu(intermediate);
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_accumulator;
ReLu<FragmentCompute> relu;
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Compute threshold optionally
intermediate = relu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
/// Scale and Bias are from input Fragment
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentScaleBias const &scale,
FragmentScaleBias const &bias) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform per-channel scale and bias
FragmentCompute intermediate;
multiply_add<FragmentCompute> mul_add_accumulator;
if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
else
intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
ReLu<FragmentCompute> relu;
// Compute threshold optionally
intermediate = relu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Special handling for int types
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
FloatRoundStyle Round
>
class LinearCombinationRelu0 <ElementOutput_, Count, int, float, Scale, Round> {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = int;
using ElementCompute = float;
static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();
static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using FragmentCompute = Array<ElementCompute, kCount>;
using FragmentScaleBias = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta = ElementCompute(0)
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr = nullptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationRelu0(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
if (Scale == ScaleType::NoBetaScaling) return true;
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
/// This is used for serial reduction which is not supported by Relu0
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
assert(k_partition == 0);
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
FragmentCompute converted_source = source_converter(source);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_add_source;
multiply_add<FragmentCompute> mul_add_accumulator;
ReLu<FragmentCompute> relu;
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
// Compute threshold optionally
intermediate = relu(intermediate);
if (platform::numeric_limits<ElementOutput>::is_integer) {
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
scaled_accumulator = compute_converter(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round>
destination_converter;
return destination_converter(scaled_accumulator);
} else {
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_accumulator;
ReLu<FragmentCompute> relu;
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Compute threshold optionally
intermediate = relu(intermediate);
if (platform::numeric_limits<ElementOutput>::is_integer) {
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
scaled_accumulator = compute_converter(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round>
destination_converter;
return destination_converter(scaled_accumulator);
} else {
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
}
/// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
/// Scale and Bias are from input Fragment
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentScaleBias const &scale,
FragmentScaleBias const &bias) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform per-channel scale and bias
FragmentCompute intermediate;
multiply_add<FragmentCompute> mul_add_accumulator;
if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
else
intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
ReLu<FragmentCompute> relu;
// Compute threshold optionally
intermediate = relu(intermediate);
if (platform::numeric_limits<ElementOutput>::is_integer) {
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
scaled_accumulator = compute_converter(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round>
destination_converter;
return destination_converter(scaled_accumulator);
} else {
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
}
};
#endif // Conditional guards to enable partial specialization for packed integers
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
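A hypothetical alias for the new epilogue (element types and vector width are assumptions); per the file comment above, the ReLU is folded into hfma2/hmul2 on SM80+, so this functor must not be combined with serial split-k reductions:

using EpilogueRelu0 = cutlass::epilogue::thread::LinearCombinationRelu0<
    cutlass::half_t,   // ElementOutput_
    8,                 // Count: elements computed per operation
    cutlass::half_t,   // ElementAccumulator_
    cutlass::half_t>;  // ElementCompute_ (Scale and Round keep their defaults)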

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -51,10 +51,12 @@ template <
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationSigmoid = LinearCombinationGeneric<Sigmoid, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round, true>;
ElementCompute_, Scale, Round, true>;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -51,10 +51,11 @@ template <
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationSilu = LinearCombinationGeneric<SiLu, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round, true>;
ElementCompute_, Scale, Round, true>;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -36,13 +36,17 @@
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/conversion_op.h"

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -224,6 +224,40 @@ struct less {
}
};
template <typename T>
struct maximum {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
return (lhs < rhs ? rhs : lhs);
}
};
template <>
struct maximum<float> {
CUTLASS_HOST_DEVICE
float operator()(float const &lhs, float const &rhs) const {
return fmaxf(lhs, rhs);
}
};
template <typename T>
struct minimum {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
return (rhs < lhs ? rhs : lhs);
}
};
template <>
struct minimum<float> {
CUTLASS_HOST_DEVICE
float operator()(float const &lhs, float const &rhs) const {
return fminf(lhs, rhs);
}
};
/// Fused multiply-add
template <typename A, typename B = A, typename C = A>
struct multiply_add {
@ -233,6 +267,16 @@ struct multiply_add {
}
};
/// Fused multiply-add
template <typename A, typename B = A, typename C = A>
struct multiply_add_relu0 {
CUTLASS_HOST_DEVICE
C operator()(A const &a, B const &b, C const &c) const {
maximum<C> mx;
return mx(C(a) * C(b) + c, C(0));
}
};
/// Fused multiply-add
template <typename T>
struct and_add {
@ -366,7 +410,6 @@ struct bit_or<Array<uint1b_t, N>> {
}
};
// Partial specializations for Arrays
template <int N>
struct bit_not<Array<uint1b_t, N>> {
@ -560,139 +603,6 @@ struct plus<Array<T, N>> {
return result;
}
};
template <typename T>
struct maximum {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
return (lhs < rhs ? rhs : lhs);
}
};
template <>
struct maximum<float> {
CUTLASS_HOST_DEVICE
float operator()(float const &lhs, float const &rhs) const {
return fmaxf(lhs, rhs);
}
};
template <typename T, int N>
struct maximum<Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
Array<T, N> result;
maximum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], rhs[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
Array<T, N> result;
maximum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], scalar);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
Array<T, N> result;
maximum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(scalar, rhs[i]);
}
return result;
}
};
template <typename T>
struct minimum {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
return (rhs < lhs ? rhs : lhs);
}
};
template <>
struct minimum<float> {
CUTLASS_HOST_DEVICE
float operator()(float const &lhs, float const &rhs) const {
return fminf(lhs, rhs);
}
};
template <typename T, int N>
struct minimum<Array<T, N>> {
CUTLASS_HOST_DEVICE
static T scalar_op(T const &lhs, T const &rhs) {
return (rhs < lhs ? rhs : lhs);
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
Array<T, N> result;
minimum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], rhs[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
Array<T, N> result;
minimum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], scalar);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
Array<T, N> result;
minimum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(scalar, rhs[i]);
}
return result;
}
};
template <typename T, int N>
struct minus<Array<T, N>> {
@ -831,6 +741,102 @@ struct divides<Array<T, N>> {
}
};
template <typename T, int N>
struct maximum<Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
Array<T, N> result;
maximum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], rhs[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
Array<T, N> result;
maximum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], scalar);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
Array<T, N> result;
maximum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(scalar, rhs[i]);
}
return result;
}
};
template <typename T, int N>
struct minimum<Array<T, N>> {
CUTLASS_HOST_DEVICE
static T scalar_op(T const &lhs, T const &rhs) {
return (rhs < lhs ? rhs : lhs);
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
Array<T, N> result;
minimum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], rhs[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
Array<T, N> result;
minimum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(lhs[i], scalar);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
Array<T, N> result;
minimum<T> scalar_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(scalar, rhs[i]);
}
return result;
}
};
template <typename T, int N>
struct negate<Array<T, N>> {
@ -897,6 +903,56 @@ struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
}
};
/// Fused multiply-add-relu0
template <typename T, int N>
struct multiply_add_relu0<Array<T, N>, Array<T, N>, Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
Array<T, N> result;
multiply_add<T> scalar_op;
maximum<T> mx;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0));
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
Array<T, N> result;
multiply_add<T> scalar_op;
maximum<T> mx;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0));
}
return result;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
Array<T, N> result;
multiply_add<T> scalar_op;
maximum<T> mx;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0));
}
return result;
}
};
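A small illustrative call of the generic Array form (sizes and values are assumptions); the half_t specialization further down lowers the same d[i] = max(a[i] * b[i] + c[i], 0) computation to __hfma2_relu on SM80+:

cutlass::multiply_add_relu0<cutlass::Array<float, 4>,
                            cutlass::Array<float, 4>,
                            cutlass::Array<float, 4>> fma_relu0;
cutlass::Array<float, 4> a, b, c;
a.fill(2.0f);
b.fill(3.0f);
c.fill(-10.0f);
cutlass::Array<float, 4> d = fma_relu0(a, b, c);  // max(2*3 - 10, 0) = 0 in every lane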
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for Array<half_t, N> targeting SIMD instructions in device code.
@ -1536,6 +1592,413 @@ struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
}
};
/// Fused multiply-add-relu0
template <int N>
struct multiply_add_relu0<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
Array<half_t, N> const &a,
Array<half_t, N> const &b,
Array<half_t, N> const &c) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
__half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
__half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]);
}
if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
__half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
__half d_residual = __hfma_relu(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1],
c_residual_ptr[N - 1]);
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
multiply_add<half_t> op;
maximum<half_t> mx;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a[i], b[i], c[i]), half_t(0));
}
#endif
return result;
}
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
half_t const &a,
Array<half_t, N> const &b,
Array<half_t, N> const &c) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
__half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
__half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]);
}
if (N % 2) {
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
__half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
__half d_residual = __hfma_relu(
reinterpret_cast<__half const &>(a),
b_residual_ptr[N - 1],
c_residual_ptr[N - 1]);
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
multiply_add<half_t> op;
maximum<half_t> mx;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a, b[i], c[i]), half_t(0));
}
#endif
return result;
}
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
Array<half_t, N> const &a,
half_t const &b,
Array<half_t, N> const &c) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
__half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
__half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]);
}
if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
__half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
__half d_residual = __hfma_relu(
a_residual_ptr[N - 1],
reinterpret_cast<__half const &>(b),
c_residual_ptr[N - 1]);
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
multiply_add<half_t> op;
maximum<half_t> mx;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a[i], b, c[i]), half_t(0));
}
#endif
return result;
}
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
Array<half_t, N> const &a,
Array<half_t, N> const &b,
half_t const &c) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
__half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
__half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair);
}
if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
__half d_residual = __hfma_relu(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1],
reinterpret_cast<__half const &>(c));
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
multiply_add<half_t> op;
maximum<half_t> mx;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a[i], b[i], c), half_t(0));
}
#endif
return result;
}
};
template <int N>
struct minimum<Array<half_t, N>> {
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]);
}
if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
__half d_residual = __hmin(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1]);
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]);
}
#endif
return result;
}
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]);
}
if (N % 2) {
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
__half d_residual = __hmin(
reinterpret_cast<__half const &>(lhs),
b_residual_ptr[N - 1]);
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (rhs[i] < lhs ? rhs[i] : lhs);
}
#endif
return result;
}
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair);
}
if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
__half d_residual = __hmin(
a_residual_ptr[N - 1],
reinterpret_cast<__half const &>(rhs));
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (rhs < lhs[i] ? rhs : lhs[i]);
}
#endif
return result;
}
};
template <int N>
struct maximum<Array<half_t, N>> {
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
}
if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
__half d_residual = __hmax(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1]);
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]);
}
#endif
return result;
}
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]);
}
if (N % 2) {
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
__half d_residual = __hmax(
reinterpret_cast<__half const &>(lhs),
b_residual_ptr[N - 1]);
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (lhs < rhs[i] ? rhs[i] : lhs);
}
#endif
return result;
}
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
}
if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
__half d_residual = __hmax(
a_residual_ptr[N - 1],
reinterpret_cast<__half const &>(rhs));
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}
#else
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (lhs[i] < rhs ? rhs : lhs[i]);
}
#endif
return result;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Fused multiply-add