[hardswish] correct implementation (#403)

* [hardswish] correct implementation
* seems working
* hardswish fp32/fp16x2 optimization
* [relu] half2 support
* add relu0; add multiply_add_relu0;
* cleanup

Co-authored-by: Bing Xu <bingxu@fb.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -169,7 +169,6 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementAccumulator,        // Data type of accumulator
    ElementComputeEpilogue>;   // Data type for alpha/beta in linear combination

using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
  ElementInputA, LayoutInputA,
  ElementInputB, LayoutInputB,
@@ -592,7 +591,7 @@ Result profile_convolution(Options const &options) {

  std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}

//
// Performance measurement
//

@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -61,17 +61,16 @@ struct ReLu {
  static const bool kIsHeavy=false;
  CUTLASS_HOST_DEVICE
  T operator()(T const & threshold, T value) const {
    if (value < threshold) {
      value = threshold;
    }
    return value;
    maximum<T> mx;

    return mx(value, threshold);
  }

  CUTLASS_HOST_DEVICE
  T operator()(T value) const {
    if (value < T(0)) {
      value = T(0);
    }
    return value;
    maximum<T> mx;

    return mx(value, T(0));
  }
};
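The rewrite above replaces the branchy clamp with the maximum<T> functor, presumably so the scalar and Array forms route through the maximum specializations consolidated in functional.h by this commit. A minimal scalar sketch of the equivalence (illustration only, not CUTLASS code):

  #include <algorithm>
  #include <cassert>

  // Clamping "value" from below at "threshold" is exactly max(value, threshold).
  float relu_branchy(float threshold, float value) {
    if (value < threshold) { value = threshold; }
    return value;
  }

  float relu_max(float threshold, float value) {
    return std::max(value, threshold);
  }

  int main() {
    float xs[] = {-2.0f, 0.0f, 3.0f};
    for (float v : xs) {
      assert(relu_branchy(0.0f, v) == relu_max(0.0f, v));
    }
  }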

@@ -80,32 +79,16 @@ struct ReLu<Array<T, N>> {
  static const bool kIsHeavy=false;
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
    Array<T, N> result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      T value = frag[i];
      if (value < threshold) {
        value = threshold;
      }
      result[i] = value;
    }
    return result;
    maximum<Array<T, N> > mx;

    return mx(threshold, frag);
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    Array<T, N> result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      T value = frag[i];
      if (value < T(0)) {
        value = T(0);
      }
      result[i] = value;
    }
    return result;
    maximum<Array<T, N> > mx;
    return mx(frag, T(0));
  }

};

// Sigmoid operator
@@ -125,7 +108,7 @@ struct Sigmoid<Array<T, N> > {
    Sigmoid<T> sigmoid_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < int(rhs.size()); ++i) {
    for (int i = 0; i < N; ++i) {
      y[i] = sigmoid_op(rhs[i]);
    }

@@ -193,7 +176,20 @@ struct HardSwish {
    minimum<T> mn;
    maximum<T> mx;
    T relu6 = mn(mx(x + T(3), T(0)), T(6));
    return x * (relu6 / T(6));
    return x * relu6 / T(6);
  }
};

template <>
struct HardSwish<float> {
  using T = float;

  CUTLASS_HOST_DEVICE
  T operator()(T const &x) const {
    minimum<T> mn;
    maximum<T> mx;
    T relu6 = mn(mx(x + T(3), T(0)), T(6));
    return x * relu6 * 0.16666667f;
  }
};

@@ -213,6 +209,21 @@ struct HardSwish<Array<T, N> > {
  }
};

template <int N>
struct HardSwish<Array<half_t, N> > {
  using T = half_t;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
    minimum<Array<T, N> > mn;
    maximum<Array<T, N> > mx;
    multiplies<Array<T, N> > mul;
    plus<Array<T, N> > add;

    return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
  }
};

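The corrected formula is hardswish(x) = x * clamp(x + 3, 0, 6) / 6; the float and half_t paths trade the division for a multiply by 0.16666667f. A standalone reference sketch (illustration only, not the CUTLASS code itself):

  #include <algorithm>
  #include <cstdio>

  // hardswish(x) = x * relu6(x + 3) / 6, with the division replaced by a
  // multiply by the nearest-float value of 1/6, as in the float specialization above.
  float hardswish_ref(float x) {
    float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
    return x * relu6 * 0.16666667f;   // equal to x * relu6 / 6 up to rounding
  }

  int main() {
    float xs[] = {-4.0f, -1.0f, 0.0f, 2.0f, 5.0f};
    for (float x : xs) {
      std::printf("hardswish(%g) = %g\n", x, hardswish_ref(x));
    }
  }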
//
|
||||
// GELU function definitions implemented as described by
|
||||
// Hendrycks, D., and Gimpel, K. in
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@@ -162,6 +162,8 @@ public:

    if (Scale == ScaleType::OnlyAlphaScaling) return false;

    if (Scale == ScaleType::Nothing) return false;

    return beta_ != ElementCompute(0);
  }

@@ -197,8 +199,15 @@ public:
    minimum<ComputeFragment> min_accumulator;
    maximum<ComputeFragment> max_accumulator;

    intermediate = mul_add_source(beta_, converted_source);                            // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);   // D = alpha * Accum + X
    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
    }
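The epilogues now branch on ScaleType instead of always evaluating the beta path; the same three-way pattern recurs throughout the files below. A scalar sketch of the three cases (plain C++ with hypothetical names, not CUTLASS code):

  enum class ScaleKind { Default, NoBetaScaling, Nothing };

  // Scalar model of the epilogue scaling added above.
  float epilogue_scale(ScaleKind scale, float alpha, float beta, float accum, float source) {
    if (scale == ScaleKind::NoBetaScaling) {
      return alpha * accum + source;         // beta is implicitly 1: D = alpha * Accum + C
    } else if (scale == ScaleKind::Nothing) {
      return accum;                          // accumulator passes through untouched
    } else {
      return alpha * accum + beta * source;  // default: D = alpha * Accum + beta * C
    }
  }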
|
||||
|
||||
/// Clamping constant value
|
||||
ElementCompute const kClampMax =
|
||||
@ -235,7 +244,11 @@ public:
|
||||
minimum<ComputeFragment> min_accumulator;
|
||||
maximum<ComputeFragment> max_accumulator;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
/// Clamping constant value
|
||||
ElementCompute const kClampMax =
|
||||
@ -367,6 +380,8 @@ public:
|
||||
|
||||
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||
|
||||
if (Scale == ScaleType::Nothing) return false;
|
||||
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
@ -399,8 +414,15 @@ public:
|
||||
multiply_add<ComputeFragment> mul_add_accumulator;
|
||||
|
||||
// Float min-max
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
if (Scale == ScaleType::NoBetaScaling) {
|
||||
intermediate = converted_source;
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
} else if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
@ -430,7 +452,11 @@ public:
|
||||
multiplies<ComputeFragment> mul_add_accumulator;
|
||||
|
||||
// Float min-max
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
@ -551,6 +577,8 @@ class FastLinearCombinationClamp {
|
||||
|
||||
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||
|
||||
if (Scale == ScaleType::Nothing) return false;
|
||||
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
@ -586,10 +614,17 @@ class FastLinearCombinationClamp {
|
||||
maximum<ComputeFragment> max_accumulator;
|
||||
|
||||
// Float min-max
|
||||
intermediate =
|
||||
mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator,
|
||||
intermediate); // D = alpha * Accum + X
|
||||
if (Scale == ScaleType::NoBetaScaling) {
|
||||
intermediate = converted_source;
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
} else if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate =
|
||||
mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator,
|
||||
intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
/// Clamping constant value
|
||||
ElementCompute const kClamp =
|
||||
@ -624,7 +659,11 @@ class FastLinearCombinationClamp {
|
||||
maximum<ComputeFragment> max_accumulator;
|
||||
|
||||
// Float min-max
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator);
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator);
|
||||
}
|
||||
|
||||
/// Clamping constant value
|
||||
ElementCompute const kClamp =
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@@ -51,11 +51,11 @@ template <
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
                                                       ElementCompute_, Round, true>;
                                                       ElementCompute_, Scale, Round, true>;
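With the Scale parameter now forwarded, callers can pick a ScaleType when instantiating these aliases (the same change is applied to the HardSwish, Sigmoid, and SiLu aliases below). A hedged example; the element types and count here are arbitrary illustration choices:

  #include "cutlass/epilogue/thread/linear_combination_gelu.h"

  // GELU epilogue that skips beta scaling of the source operand.
  using GeluNoBeta = cutlass::epilogue::thread::LinearCombinationGELU<
      cutlass::half_t,                                  // ElementOutput_
      8,                                                // Count
      float,                                            // ElementAccumulator_
      float,                                            // ElementCompute_
      cutlass::epilogue::thread::ScaleType::NoBetaScaling>;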
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -54,6 +54,7 @@ template <
|
||||
///< but we use 64 or 32 sometimes when there are not enough data to store
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
||||
bool IsHeavy = false
|
||||
>
|
||||
@ -66,10 +67,11 @@ public:
|
||||
|
||||
static bool const kIsHeavy = IsHeavy;
|
||||
static int const kCount = Count;
|
||||
static const ScaleType::Kind kScale = Scale;
|
||||
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
||||
using ComputeFragment = Array<ElementCompute, kCount>;
|
||||
using FragmentCompute = Array<ElementCompute, kCount>;
|
||||
|
||||
static FloatRoundStyle const kRound = Round;
|
||||
|
||||
@ -131,6 +133,12 @@ public:
|
||||
/// Returns true if source is needed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_source_needed() const {
|
||||
if (Scale == ScaleType::NoBetaScaling) return true;
|
||||
|
||||
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||
|
||||
if (Scale == ScaleType::Nothing) return false;
|
||||
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
@ -152,19 +160,26 @@ public:
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_source = source_converter(source);
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
FragmentCompute converted_source = source_converter(source);
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_source;
|
||||
multiply_add<ComputeFragment> mul_add_accumulator;
|
||||
ActivationFunctor<ComputeFragment> activation;
|
||||
multiplies<FragmentCompute> mul_add_source;
|
||||
multiply_add<FragmentCompute> mul_add_accumulator;
|
||||
ActivationFunctor<FragmentCompute> activation;
|
||||
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
if (Scale == ScaleType::NoBetaScaling) {
|
||||
intermediate = converted_source;
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
} else if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
intermediate = activation(intermediate);
|
||||
|
||||
@ -182,16 +197,20 @@ public:
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_accumulator;
|
||||
ActivationFunctor<ComputeFragment> activation;
|
||||
multiplies<FragmentCompute> mul_add_accumulator;
|
||||
ActivationFunctor<FragmentCompute> activation;
|
||||
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
intermediate = activation(intermediate);
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -51,10 +51,11 @@ template <
|
||||
///< but we use 64 or 32 sometimes when there are not enough data to store
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using LinearCombinationHardSwish = LinearCombinationGeneric<HardSwish, ElementOutput_, Count, ElementAccumulator_,
|
||||
ElementCompute_, Round>;
|
||||
ElementCompute_, Scale, Round>;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -156,6 +156,8 @@ public:
|
||||
|
||||
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||
|
||||
if (Scale == ScaleType::Nothing) return false;
|
||||
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
@ -193,12 +195,15 @@ public:
|
||||
multiply_add<FragmentCompute> mul_add_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
if (Scale == ScaleType::NoBetaScaling)
|
||||
if (Scale == ScaleType::NoBetaScaling) {
|
||||
intermediate = converted_source;
|
||||
else
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
} else if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
@ -225,7 +230,11 @@ public:
|
||||
multiplies<FragmentCompute> mul_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
@ -269,8 +278,6 @@ public:
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -376,6 +383,8 @@ public:
|
||||
|
||||
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||
|
||||
if (Scale == ScaleType::Nothing) return false;
|
||||
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
@ -413,12 +422,15 @@ public:
|
||||
multiply_add<FragmentCompute> mul_add_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
if (Scale == ScaleType::NoBetaScaling)
|
||||
if (Scale == ScaleType::NoBetaScaling) {
|
||||
intermediate = converted_source;
|
||||
else
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
} else if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
@ -459,7 +471,11 @@ public:
|
||||
multiplies<FragmentCompute> mul_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
@ -512,22 +528,13 @@ public:
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
|
||||
if (platform::is_same<ElementOutput, int32_t>::value ||
|
||||
platform::is_same<ElementOutput, uint32_t>::value ||
|
||||
platform::is_same<ElementOutput, int16_t>::value ||
|
||||
platform::is_same<ElementOutput, uint16_t>::value ||
|
||||
platform::is_same<ElementOutput, int8_t>::value ||
|
||||
platform::is_same<ElementOutput, uint8_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value) {
|
||||
if (platform::numeric_limits<ElementOutput>::is_integer) {
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = __float2int_rn(intermediate[i]);
|
||||
}
|
||||
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
|
||||
|
||||
scaled_accumulator = compute_converter(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
||||
@ -540,7 +547,6 @@ public:
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif // Conditional guards to enable partial specialization for packed integers
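The hunks above swap the hand-written is_same chain and per-element __float2int_rn loop for platform::numeric_limits<ElementOutput>::is_integer plus a NumericArrayConverter. A small illustrative helper showing that conversion idiom (the fragment width of 4 is an assumption, not taken from the commit):

  #include "cutlass/array.h"
  #include "cutlass/numeric_conversion.h"

  // Round a float fragment to int the way the integer-output epilogue path now does.
  cutlass::Array<int, 4> round_fragment(cutlass::Array<float, 4> const &intermediate) {
    cutlass::NumericArrayConverter<int, float, 4,
                                   cutlass::FloatRoundStyle::round_to_nearest> converter;
    return converter(intermediate);
  }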
|
||||
|
||||
include/cutlass/epilogue/thread/linear_combination_relu0.h (new file, 535 lines)
@@ -0,0 +1,535 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
    \brief Functor performing linear combination with a relu operation used by epilogues.

    This one only supports relu0 and tries to fold the relu into other instructions; thus,
    serial split-k is not supported by this functor. For example, relu can be folded into
    hfma2/hmul2 for sm80+.
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/half.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/epilogue/thread/scale_type.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Single source of truth for whether to unroll for `LinearCombinationRelu0()`
|
||||
constexpr bool LinearCombinationRelu0IsHeavy() {
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Applies a linear combination operator to an array of elements.
|
||||
///
|
||||
/// D = alpha * accumulator + beta * source + uniform
|
||||
///
|
||||
template <
|
||||
typename ElementOutput_, ///< Data type used to load and store tensors
|
||||
int Count, ///< Number of elements computed per operation
|
||||
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
||||
///< but we use 64 or 32 sometimes when there are not enough data to store
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
class LinearCombinationRelu0 {
|
||||
public:
|
||||
|
||||
using ElementOutput = ElementOutput_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
|
||||
static int const kCount = Count;
|
||||
static const ScaleType::Kind kScale = Scale;
|
||||
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
||||
using FragmentCompute = Array<ElementCompute, kCount>;
|
||||
using FragmentScaleBias = Array<ElementCompute, kCount>;
|
||||
|
||||
static FloatRoundStyle const kRound = Round;
|
||||
|
||||
static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr = nullptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object, possibly loading from pointers in host memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
LinearCombinationRelu0(Params const ¶ms) {
|
||||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_source_needed() const {
|
||||
if (Scale == ScaleType::NoBetaScaling) return true;
|
||||
|
||||
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||
|
||||
if (Scale == ScaleType::Nothing) return false;
|
||||
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
/// This is used for serial reduction which is not supported by Relu0
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
assert(k_partition == 0);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentOutput const &source) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
FragmentCompute converted_source = source_converter(source);
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiplies<FragmentCompute> mul_add_source;
|
||||
multiply_add_relu0<FragmentCompute> mul_add_relu0_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
if (Scale == ScaleType::NoBetaScaling) {
|
||||
intermediate = converted_source;
|
||||
intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
} else if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(intermediate);
|
||||
} else {
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiplies<FragmentCompute> mul_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
|
||||
/// Scale and Bias are from input Fragment
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentScaleBias const &scale,
|
||||
FragmentScaleBias const &bias) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform per-channel scale and bias
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiply_add<FragmentCompute> mul_add_accumulator;
|
||||
|
||||
if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
|
||||
intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
|
||||
else
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
|
||||
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
};
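A hedged usage sketch of the new thread-level functor; the float element types and a Count of 4 are arbitrary illustration choices, not taken from the commit:

  #include "cutlass/epilogue/thread/linear_combination_relu0.h"

  using Relu0Epilogue = cutlass::epilogue::thread::LinearCombinationRelu0<float, 4>;

  // Computes relu0(alpha * accum + beta * source) on one fragment.
  CUTLASS_HOST_DEVICE
  Relu0Epilogue::FragmentOutput apply_epilogue(
      Relu0Epilogue::FragmentAccumulator const &accum,
      Relu0Epilogue::FragmentOutput const &source) {

    Relu0Epilogue::Params params(/*alpha=*/1.0f, /*beta=*/1.0f);
    Relu0Epilogue op(params);
    return op(accum, source);
  }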
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conditional guards to enable partial specialization for packed integers
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
|
||||
|
||||
/// Applies a linear combination operator to an array of elements.
|
||||
///
|
||||
/// D = alpha * accumulator + beta * source + uniform
|
||||
///
|
||||
/// Special handling for int types
|
||||
|
||||
template <
|
||||
typename ElementOutput_, ///< Data type used to load and store tensors
|
||||
int Count, ///< Number of elements computed per operation
|
||||
ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
class LinearCombinationRelu0 <ElementOutput_, Count, int, float, Scale, Round> {
|
||||
public:
|
||||
|
||||
using ElementOutput = ElementOutput_;
|
||||
using ElementAccumulator = int;
|
||||
using ElementCompute = float;
|
||||
|
||||
static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();
|
||||
|
||||
static int const kCount = Count;
|
||||
static const ScaleType::Kind kScale = Scale;
|
||||
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
||||
using FragmentCompute = Array<ElementCompute, kCount>;
|
||||
using FragmentScaleBias = Array<ElementCompute, kCount>;
|
||||
|
||||
static FloatRoundStyle const kRound = Round;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr = nullptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object, possibly loading from pointers in host memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
LinearCombinationRelu0(Params const ¶ms) {
|
||||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_source_needed() const {
|
||||
if (Scale == ScaleType::NoBetaScaling) return true;
|
||||
|
||||
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||
|
||||
if (Scale == ScaleType::Nothing) return false;
|
||||
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
/// This is used for serial reduction which is not supported by Relu0
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
assert(k_partition == 0);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentOutput const &source) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
FragmentCompute converted_source = source_converter(source);
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiplies<FragmentCompute> mul_add_source;
|
||||
multiply_add<FragmentCompute> mul_add_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
if (Scale == ScaleType::NoBetaScaling) {
|
||||
intermediate = converted_source;
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
} else if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(intermediate);
|
||||
|
||||
if (platform::numeric_limits<ElementOutput>::is_integer) {
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
|
||||
|
||||
scaled_accumulator = compute_converter(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
||||
destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
} else {
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
||||
destination_converter;
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiplies<FragmentCompute> mul_accumulator;
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
if (Scale == ScaleType::Nothing) {
|
||||
intermediate = converted_accumulator;
|
||||
} else {
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(intermediate);
|
||||
|
||||
if (platform::numeric_limits<ElementOutput>::is_integer) {
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
|
||||
|
||||
scaled_accumulator = compute_converter(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
||||
destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
} else {
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
||||
destination_converter;
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
|
||||
/// Scale and Bias are from input Fragment
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentScaleBias const &scale,
|
||||
FragmentScaleBias const &bias) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform per-channel scale and bias
|
||||
FragmentCompute intermediate;
|
||||
|
||||
multiply_add<FragmentCompute> mul_add_accumulator;
|
||||
|
||||
if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
|
||||
intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
|
||||
else
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
|
||||
|
||||
ReLu<FragmentCompute> relu;
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(intermediate);
|
||||
|
||||
if (platform::numeric_limits<ElementOutput>::is_integer) {
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
|
||||
|
||||
scaled_accumulator = compute_converter(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
||||
destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
} else {
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
||||
destination_converter;
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // Conditional guards to enable partial specialization for packed integers
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -51,10 +51,12 @@ template <
|
||||
///< but we use 64 or 32 sometimes when there are not enough data to store
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using LinearCombinationSigmoid = LinearCombinationGeneric<Sigmoid, ElementOutput_, Count, ElementAccumulator_,
|
||||
ElementCompute_, Round, true>;
|
||||
ElementCompute_, Scale, Round, true>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -51,10 +51,11 @@ template <
|
||||
///< but we use 64 or 32 sometimes when there are not enough data to store
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using LinearCombinationSilu = LinearCombinationGeneric<SiLu, ElementOutput_, Count, ElementAccumulator_,
|
||||
ElementCompute_, Round, true>;
|
||||
ElementCompute_, Scale, Round, true>;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -36,13 +36,17 @@
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/conversion_op.h"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@@ -224,6 +224,40 @@ struct less {
  }
};

template <typename T>
struct maximum {

  CUTLASS_HOST_DEVICE
  T operator()(T const &lhs, T const &rhs) const {
    return (lhs < rhs ? rhs : lhs);
  }
};

template <>
struct maximum<float> {
  CUTLASS_HOST_DEVICE
  float operator()(float const &lhs, float const &rhs) const {
    return fmaxf(lhs, rhs);
  }
};

template <typename T>
struct minimum {

  CUTLASS_HOST_DEVICE
  T operator()(T const &lhs, T const &rhs) const {
    return (rhs < lhs ? rhs : lhs);
  }
};

template <>
struct minimum<float> {
  CUTLASS_HOST_DEVICE
  float operator()(float const &lhs, float const &rhs) const {
    return fminf(lhs, rhs);
  }
};

/// Fused multiply-add
template <typename A, typename B = A, typename C = A>
struct multiply_add {
@@ -233,6 +267,16 @@ struct multiply_add {
  }
};

/// Fused multiply-add
template <typename A, typename B = A, typename C = A>
struct multiply_add_relu0 {
  CUTLASS_HOST_DEVICE
  C operator()(A const &a, B const &b, C const &c) const {
    maximum<C> mx;
    return mx(C(a) * C(b) + c, C(0));
  }
};
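The new multiply_add_relu0 functor computes d = max(a * b + c, 0). A small host-side sketch of its semantics (illustration only):

  #include <cassert>
  #include <cmath>
  #include "cutlass/functional.h"

  int main() {
    cutlass::multiply_add_relu0<float> fma_relu0;
    assert(fma_relu0(2.0f, -3.0f, 1.0f) == 0.0f);   // 2 * -3 + 1 = -5 -> clamped to 0
    assert(fma_relu0(2.0f,  3.0f, 1.0f) == 7.0f);   // 2 *  3 + 1 =  7 -> passes through
  }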
|
||||
|
||||
/// Fused multiply-add
|
||||
template <typename T>
|
||||
struct and_add {
|
||||
@ -366,7 +410,6 @@ struct bit_or<Array<uint1b_t, N>> {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Partial specializations for Arrays
|
||||
template <int N>
|
||||
struct bit_not<Array<uint1b_t, N>> {
|
||||
@ -560,139 +603,6 @@ struct plus<Array<T, N>> {
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct maximum {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &lhs, T const &rhs) const {
|
||||
return (lhs < rhs ? rhs : lhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct maximum<float> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
float operator()(float const &lhs, float const &rhs) const {
|
||||
return fmaxf(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct maximum<Array<T, N>> {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
maximum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
||||
|
||||
Array<T, N> result;
|
||||
maximum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], scalar);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
maximum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(scalar, rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct minimum {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &lhs, T const &rhs) const {
|
||||
return (rhs < lhs ? rhs : lhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct minimum<float> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
float operator()(float const &lhs, float const &rhs) const {
|
||||
return fminf(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct minimum<Array<T, N>> {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static T scalar_op(T const &lhs, T const &rhs) {
|
||||
return (rhs < lhs ? rhs : lhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
minimum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
||||
|
||||
Array<T, N> result;
|
||||
minimum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], scalar);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
minimum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(scalar, rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct minus<Array<T, N>> {
|
||||
|
||||
@ -831,6 +741,102 @@ struct divides<Array<T, N>> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct maximum<Array<T, N>> {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
maximum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
||||
|
||||
Array<T, N> result;
|
||||
maximum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], scalar);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
maximum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(scalar, rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct minimum<Array<T, N>> {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static T scalar_op(T const &lhs, T const &rhs) {
|
||||
return (rhs < lhs ? rhs : lhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
minimum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
||||
|
||||
Array<T, N> result;
|
||||
minimum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(lhs[i], scalar);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
||||
|
||||
Array<T, N> result;
|
||||
minimum<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(scalar, rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};

template <typename T, int N>
struct negate<Array<T, N>> {

@ -897,6 +903,56 @@ struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
}
};

/// Fused multiply-add-relu0
template <typename T, int N>
struct multiply_add_relu0<Array<T, N>, Array<T, N>, Array<T, N>> {

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {

Array<T, N> result;
multiply_add<T> scalar_op;
maximum<T> mx;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0));
}

return result;
}

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {

Array<T, N> result;
multiply_add<T> scalar_op;
maximum<T> mx;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0));
}

return result;
}

CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {

Array<T, N> result;
multiply_add<T> scalar_op;
maximum<T> mx;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0));
}

return result;
}
};
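
// Illustrative usage sketch (not part of this diff): multiply_add_relu0 fuses
// d[i] = max(a[i] * b[i] + c[i], 0) into a single functor call over a fragment.
// The function name, value type float, and fragment width 8 are arbitrary choices
// for the example.
CUTLASS_HOST_DEVICE
cutlass::Array<float, 8> example_fma_relu0(
  cutlass::Array<float, 8> const &a,
  cutlass::Array<float, 8> const &b,
  cutlass::Array<float, 8> const &c) {
  cutlass::multiply_add_relu0<
    cutlass::Array<float, 8>,
    cutlass::Array<float, 8>,
    cutlass::Array<float, 8>> fma_relu;
  return fma_relu(a, b, c);
}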

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for Array<half_t, N> targeting SIMD instructions in device code.

@ -1536,6 +1592,413 @@ struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
}
};

/// Fused multiply-add-relu0
template <int N>
struct multiply_add_relu0<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
Array<half_t, N> const &a,
Array<half_t, N> const &b,
Array<half_t, N> const &c) const {

Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
__half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
__half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]);
}

if (N % 2) {

__half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
__half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

__half d_residual = __hfma_relu(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1],
c_residual_ptr[N - 1]);

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

multiply_add<half_t> op;
maximum<half_t> mx;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a[i], b[i], c[i]), (half_t)0);
}
#endif

return result;
}

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
half_t const &a,
Array<half_t, N> const &b,
Array<half_t, N> const &c) const {

Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
__half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
__half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]);
}

if (N % 2) {

__half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
__half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
__half d_residual = __hfma_relu(
reinterpret_cast<__half const &>(a),
b_residual_ptr[N - 1],
c_residual_ptr[N - 1]);

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

multiply_add<half_t> op;
maximum<half_t> mx;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a, b[i], c[i]), half_t(0));
}
#endif

return result;
}

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
Array<half_t, N> const &a,
half_t const &b,
Array<half_t, N> const &c) const {

Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
__half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
__half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]);
}

if (N % 2) {

__half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
__half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

__half d_residual = __hfma_relu(
a_residual_ptr[N - 1],
reinterpret_cast<__half const &>(b),
c_residual_ptr[N - 1]);

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

multiply_add<half_t> op;
maximum<half_t> mx;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a[i], b, c[i]), half_t(0));
}
#endif

return result;
}

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(
Array<half_t, N> const &a,
Array<half_t, N> const &b,
half_t const &c) const {

Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
__half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
__half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair);
}

if (N % 2) {

__half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

__half d_residual = __hfma_relu(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1],
reinterpret_cast<__half const &>(c));

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

multiply_add<half_t> op;
maximum<half_t> mx;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = mx(op(a[i], b[i], c), half_t(0));
}
#endif

return result;
}
};
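
// Illustrative usage sketch (not part of this diff): for half_t fragments the
// specialization above packs element pairs into __half2 and issues __hfma2_relu
// on SM80 and newer, falling back to scalar multiply_add plus maximum elsewhere.
// An even fragment width (8 here, chosen arbitrarily) keeps every element on the
// packed path; the scalar overload broadcasts alpha across the fragment. The
// function name and parameter names are hypothetical.
CUTLASS_HOST_DEVICE
cutlass::Array<cutlass::half_t, 8> example_fp16_fma_relu0(
  cutlass::Array<cutlass::half_t, 8> const &a,
  cutlass::half_t alpha,
  cutlass::Array<cutlass::half_t, 8> const &c) {
  cutlass::multiply_add_relu0<
    cutlass::Array<cutlass::half_t, 8>,
    cutlass::Array<cutlass::half_t, 8>,
    cutlass::Array<cutlass::half_t, 8>> fma_relu;
  // d[i] = max(a[i] * alpha + c[i], 0)
  return fma_relu(a, alpha, c);
}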

template <int N>
struct minimum<Array<half_t, N>> {
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]);
}

if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

__half d_residual = __hmin(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1]);

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]);
}
#endif

return result;
}

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]);
}

if (N % 2) {
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

__half d_residual = __hmin(
reinterpret_cast<__half const &>(lhs),
b_residual_ptr[N - 1]);

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (rhs[i] < lhs ? rhs[i] : lhs);
}
#endif

return result;
}

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair);
}

if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

__half d_residual = __hmin(
a_residual_ptr[N - 1],
reinterpret_cast<__half const &>(rhs));

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (rhs < lhs[i] ? rhs : lhs[i]);
}
#endif

return result;
}
};

template <int N>
struct maximum<Array<half_t, N>> {
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
}

if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

__half d_residual = __hmax(
a_residual_ptr[N - 1],
b_residual_ptr[N - 1]);

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]);
}
#endif

return result;
}

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]);
}

if (N % 2) {
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

__half d_residual = __hmax(
reinterpret_cast<__half const &>(lhs),
b_residual_ptr[N - 1]);

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (lhs < rhs[i] ? rhs[i] : lhs);
}
#endif

return result;
}

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
__half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
}

if (N % 2) {
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

__half d_residual = __hmax(
a_residual_ptr[N - 1],
reinterpret_cast<__half const &>(rhs));

result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
}

#else

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = (lhs[i] < rhs ? rhs : lhs[i]);
}
#endif

return result;
}
};
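
// Illustrative usage sketch (not part of this diff): the packed __hmax2 / __hmin2
// specializations above make a fragment-wide clamp cheap on SM80 and newer, e.g.
// the relu6-style clamp that hardswish builds on. The function name and fragment
// width 8 are arbitrary choices for the example.
CUTLASS_HOST_DEVICE
cutlass::Array<cutlass::half_t, 8> example_fp16_relu6(
  cutlass::Array<cutlass::half_t, 8> const &x) {
  cutlass::maximum<cutlass::Array<cutlass::half_t, 8>> mx;
  cutlass::minimum<cutlass::Array<cutlass::half_t, 8>> mn;
  // Element-wise clamp to [0, 6].
  return mn(mx(x, cutlass::half_t(0)), cutlass::half_t(6));
}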

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Fused multiply-add