Add epilogue functor for residual block fusion (#391)
* Add epilogue functor for residual block fusion * Do not run split-k tests when ActivationOp is not Identity * explain TestSplitK param * return early
This commit is contained in:
@ -68,8 +68,8 @@ struct ReLu {
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T value) const {
|
||||
if (value < T()) {
|
||||
value = T();
|
||||
if (value < T(0)) {
|
||||
value = T(0);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
@ -91,6 +91,21 @@ struct ReLu<Array<T, N>> {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
Array<T, N> result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
T value = frag[i];
|
||||
if (value < T(0)) {
|
||||
value = T(0);
|
||||
}
|
||||
result[i] = value;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// Sigmoid operator
|
||||
@ -151,7 +166,8 @@ template <typename T>
|
||||
struct SiLu {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &scalar) const {
|
||||
return scalar * Sigmoid<T>(scalar);
|
||||
Sigmoid<T> sigmoid;
|
||||
return scalar * sigmoid(scalar);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -0,0 +1,163 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Epilogue functor specialized for residual blocks in deep neural network.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
|
||||
template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
||||
template <typename T> class ActivationOp_,
|
||||
template <typename T> class BinaryOp_,
|
||||
template <typename T> class UnaryOp_>
|
||||
class LinearCombinationResidualBlock {
|
||||
public:
|
||||
|
||||
using ElementOutput = ElementC_;
|
||||
using ElementC = ElementC_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
|
||||
using BinaryOp = BinaryOp_<Array<ElementCompute, kCount>>;
|
||||
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
|
||||
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
||||
using FragmentC = Array<ElementC, kElementsPerAccess>;
|
||||
using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
using ElementZ = ElementOutput_;
|
||||
using ElementT = ElementZ;
|
||||
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
||||
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
||||
|
||||
static bool const kIsHeavy = true;
|
||||
static bool const kStoreZ = true;
|
||||
static bool const kStoreT = false;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales residual input
|
||||
ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(ElementCompute alpha, ElementCompute beta)
|
||||
: alpha(alpha), beta(beta) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
|
||||
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
bool skip_elementwise_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor from Params
|
||||
CUTLASS_HOST_DEVICE
|
||||
LinearCombinationResidualBlock(Params const ¶ms) {
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
skip_elementwise_ = false;
|
||||
}
|
||||
|
||||
/// The "source" tensor corresponds to the residual input
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_source_needed() const { return true; }
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
/// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
if (k_partition != k_partition_count - 1) {
|
||||
skip_elementwise_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
|
||||
FragmentC const &residual,
|
||||
FragmentCompute const &bias) const {
|
||||
UnaryOp unary_op;
|
||||
BinaryOp binary_op;
|
||||
ActivationOp activation;
|
||||
|
||||
FragmentCompute tmp_Accum =
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
||||
FragmentCompute tmp_residual =
|
||||
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);
|
||||
|
||||
FragmentCompute z =
|
||||
binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
|
||||
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
|
||||
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
|
||||
frag_Z = convert_z(result_Z);
|
||||
}
|
||||
|
||||
/// Should never be called
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
|
||||
FragmentCompute const &) const {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
Reference in New Issue
Block a user