Gemm broadcast (#632)

* gemm_universal_with_broadcast, +2 sources.
* Revert "gemm_universal_with_broadcast, +2 sources." (this reverts commit fb063251f2)
* gemm_universal_with_broadcast, separated version.
* Update copyright banner.
* Update banner.
@@ -53,6 +53,8 @@ namespace thread {

template <typename T>
struct Identity {
  static const bool kIsHeavy = false;

  CUTLASS_HOST_DEVICE
  T operator()(T value) const {
    return value;

@@ -160,7 +162,7 @@ struct LeakyReLU {
    Params():
      LinearCombinationGenericParams<T>(),
      leaky_alpha(T(1)) {}

    CUTLASS_HOST_DEVICE
    Params(
      T alpha,

@@ -194,7 +196,7 @@ struct LeakyReLU<Array<T, N> > {
    Params():
      LinearCombinationGenericParams<T>(),
      leaky_alpha(T(1)) {}

    CUTLASS_HOST_DEVICE
    Params(
      T alpha,

@@ -464,7 +466,7 @@ struct HardSwish<Array<half_t, N> > {
    maximum<Array<T, N> > mx;
    multiplies<Array<T, N> > mul;
    plus<Array<T, N> > add;

    return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
  }

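For reference, the vectorized expression in the HardSwish hunk above evaluates the standard hard-swish, with the division by six folded into the 0.16666667f multiplier:

  \mathrm{HardSwish}(x) = x \cdot \tfrac{1}{6}\,\min\bigl(\max(x + 3,\, 0),\, 6\bigr)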
@@ -493,7 +495,7 @@ struct GELU {
    return T(cutlass::constants::half<T>() * scalar *
      (cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE

@@ -509,7 +511,7 @@ struct GELU<float> {
    return cutlass::constants::half<float>() * scalar *
      (cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
  }

  using Params = LinearCombinationGenericParams<float>;

  CUTLASS_HOST_DEVICE

@@ -525,7 +527,7 @@ struct GELU<double> {
    return cutlass::constants::half<double>() * scalar *
      (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
  }

  using Params = LinearCombinationGenericParams<double>;

  CUTLASS_HOST_DEVICE
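All three GELU bodies above evaluate the exact erf-based definition, using erff for the float path and erf for the double path:

  \mathrm{GELU}(x) = \tfrac{x}{2}\left(1 + \operatorname{erf}\!\left(x / \sqrt{2}\right)\right)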
@@ -548,7 +550,7 @@ struct GELU<Array<T, N> > {

    return y;
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE

@@ -572,7 +574,7 @@ struct GELU_taylor {
  }

  using Params = LinearCombinationGenericParams<T>;

};

template <int N>

@@ -618,7 +620,7 @@ struct GELU_taylor<Array<T, N> > {

    return y;
  }

  using Params = LinearCombinationGenericParams<T>;
};

@@ -0,0 +1,197 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief Epilogue functor specialized for residual blocks in deep neural networks.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
/// or of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
template <typename ElementOutput_, typename ElementAccumulator_,
          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
          template <typename T> class ActivationOp_,
          template <typename T> class BinaryOp1_,
          template <typename T> class UnaryOp_,
          template <typename T> class BinaryOp2_ = BinaryOp1_>
class LinearCombinationResidualBlockV2 {
public:

  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
  using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementC, kElementsPerAccess>;
  using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;

  using ElementZ = ElementOutput_;
  using ElementT = ElementZ;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  static bool const kIsHeavy = true;
  static bool const kStoreZ = true;
  static bool const kStoreT = false;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                     ///< scales accumulators
    ElementCompute beta;                      ///< scales residual input
    ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr{nullptr};  ///< pointer to residual scalar - if not null, loads it from memory

    CUTLASS_HOST_DEVICE
    Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha, ElementCompute beta)
        : alpha(alpha), beta(beta) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

private:

  ElementCompute alpha_;
  ElementCompute beta_;
  bool skip_elementwise_;

public:

  /// Constructor from Params
  CUTLASS_HOST_DEVICE
  LinearCombinationResidualBlockV2(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// The "source" tensor corresponds to the residual input
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  /// Functionally required for serial reduction in the epilogue
  /// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
    }
  }

  /// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp1 binary_op;
    ActivationOp activation;

    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);

    FragmentCompute z =
        binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual1, FragmentC const &residual2,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp1 binary_op1;
    BinaryOp2 binary_op2;
    ActivationOp activation;

    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual1 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
    FragmentCompute tmp_residual2 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);

    FragmentCompute z =
        binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Should never be called
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
                  FragmentCompute const &) const {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
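For orientation, a minimal sketch of how this functor might be instantiated. The operator choices (ReLu activation, plus as the binary op, Identity as the unary op) and the element types are illustrative assumptions, not taken from this commit; the include for the new header itself is omitted because its path is not shown in this diff view.

  #include "cutlass/epilogue/thread/activation.h"  // ReLu, Identity

  // Hypothetical instantiation (assumed types): computes
  //   Z = Identity(plus(ReLu(alpha * AB + bias), beta * residual))
  // over vectors of 8 half-precision elements (one 128-bit access).
  using ResidualOp = cutlass::epilogue::thread::LinearCombinationResidualBlockV2<
      cutlass::half_t,                       // ElementOutput_
      float,                                 // ElementAccumulator_
      float,                                 // ElementCompute_
      cutlass::half_t,                       // ElementC_
      8,                                     // ElementsPerAccess
      cutlass::epilogue::thread::ReLu,       // ActivationOp_
      cutlass::plus,                         // BinaryOp1_
      cutlass::epilogue::thread::Identity>;  // UnaryOp_ (BinaryOp2_ defaults to BinaryOp1_)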
@@ -0,0 +1,177 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess,
  bool ScatterD = false
>
struct DefaultEpilogueWithBroadcastTensorOpV2 {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementOutput,
    ScatterD
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithBroadcastV2<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding,
    Base::kFragmentsPerIteration
  >;
};

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for VoltaTensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultEpilogueWithBroadcastVoltaTensorOpV2 {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithBroadcastV2<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding
  >;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
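A minimal sketch of consuming the helper above. WarpMma and BroadcastOp stand in for types that would come from the rest of the kernel and are assumptions here, as are the tile shape and element types:

  // Hypothetical composition (assumed WarpMma and BroadcastOp): the helper
  // bundles the output/tensor tile iterators and the broadcast epilogue.
  using DefaultEpi = cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOpV2<
      cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
      WarpMma,                                 // warp-level tensor op (assumed)
      1,                                       // PartitionsK
      cutlass::half_t,                         // ElementOutput
      cutlass::half_t,                         // ElementTensor
      cutlass::half_t,                         // ElementVector
      BroadcastOp,                             // output op modeling EpilogueWithBroadcastOp (assumed)
      8>;                                      // ElementsPerAccess
  using Epilogue = typename DefaultEpi::Epilogue;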
@@ -0,0 +1,847 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include <utility>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h"

#include "cutlass/util/index_sequence.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp
template <
  typename ElementC_,
  typename ElementAccumulator_,
  typename ElementCompute_,
  typename ElementZ_,
  typename ElementT_,
  int ElementsPerAccess,
  bool StoreZ = true,
  bool StoreT = true
>
struct EpilogueWithBroadcastOpBaseV2 {

  using ElementOutput = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementZ = ElementZ_;
  using ElementT = ElementT_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  /// If true, the 'Z' tensor is stored
  static bool const kStoreZ = StoreZ;

  /// If true, the 'T' tensor is stored
  static bool const kStoreT = StoreT;

  /// Parameters structure - required
  struct Params { };

  //
  // Methods
  //

  /// Constructor from Params
  EpilogueWithBroadcastOpBaseV2(Params const &params_) { }

  /// Determine if the source is needed. May return false if
  bool is_source_needed() const {
    return true;
  }

  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) { }

  /// Applies the operation when is_source_needed() is true
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentC const &frag_C1,
    FragmentC const &frag_C2,
    FragmentCompute const &V) const {

  }

  /// Applies the operation when is_source_needed() is false
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentCompute const &V) const {

  }
};

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///
///  Z, T = OutputOp(AB, C, Broadcast)
///
///  if (ElementwiseOp::kStoreZ) {
///    store(converted_u);
///  }
///
///  if (ElementwiseOp::kStoreT) {
///    store(v);
///  }
///
template <
  typename Shape_,                        ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,              ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                        ///< Number of partitions of the K dimension
  typename OutputTileIterator_,           ///< Tile iterator reading and writing output tensors (z)
  typename TensorTileIterator_,           ///< Additional tile iterator for tensor-valued operands (t)
  typename ElementVector_,                ///< Pointer to broadcast vector
  typename AccumulatorFragmentIterator_,  ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,             ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,           ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                     ///< Output operator - concept is EpilogueWithBroadcastOp
  typename Padding_,                      ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1,          ///< Used to coarsen the epilogue granularity
  int IterationsUnroll =                  ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class EpilogueWithBroadcastV2 :
  public EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition> {

public:

  using Base = EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition>;

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename Base::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Compute data type produced by the output op
  using ElementCompute = typename OutputOp::ElementCompute;

  /// Compute fragment
  using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;

  /// Thread map used by output tile iterators
  using ThreadMap = typename OutputTileIterator::ThreadMap;

  /// Fragment object used to store the broadcast values
  using BroadcastFragment = Array<
    ElementCompute,
    ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Data type of additional tensor
  using ElementTensor = typename TensorTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
    typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;

  /// Tensor access type
  using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = typename Base::WarpCount;

  /// Shared memory allocation from epilogue base class
  using BaseSharedStorage = typename Base::SharedStorage;

  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
  static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;

  /// Used for the broadcast
  struct BroadcastDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = kWarpSize * WarpCount::kCount;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    /// I'm not sure what I meant here.
    static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);

    /// Shape of the shared memory allocation for the epilogue
    using StorageShape = MatrixShape<
      kThreadRows,
      Shape::kN
    >;

    /// Debug printing
    CUTLASS_DEVICE
    static void print() {
#if 0
      printf("BroadcastDetail {\n");
      printf(
        "  kColumnsPerThread: %d\nkRowsPerThread: %d\nkThreadCount: %d\nkThreadsPerRow: %d\n"
        "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
        kColumnsPerThread,
        kRowsPerThread,
        kThreadCount,
        kThreadsPerRow,
        kThreadRows,
        kThreadAccessesPerRow,
        StorageShape::kRow,
        StorageShape::kColumn,
        StorageShape::kCount
      );
      printf("};\n");
#endif
    }
  };

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  struct SharedStorage {
    union {
      BaseSharedStorage base;
    };

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

public:

  static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
    "Mismatch between shared load iterator and output tile iterator.");

  static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
    "Divisibility");

private:

  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Thread index within the threadblock
  int thread_idx_;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueWithBroadcastV2(
    SharedStorage &shared_storage,  ///< Shared storage object
    int thread_idx,                 ///< ID of a thread within the threadblock
    int warp_idx,                   ///< ID of warp within threadblock
    int lane_idx                    ///< Id of thread within warp
  ):
    Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
    shared_load_iterator_(shared_storage.base.reference(), thread_idx),
    thread_idx_(thread_idx)
  {

  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                ///< Output operator
    ElementVector const * broadcast_ptr,      ///< Broadcast vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    AccumulatorTile const &accumulators,      ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator1,      ///< Tile iterator for source accumulator matrix
    OutputTileIterator source_iterator2,      ///< Tile iterator for source accumulator matrix
    TensorTileIterator tensor_iterator,       ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =         ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =   ///< Threadblock's initial offset within the problem size space
        MatrixCoord()) {

    BroadcastFragment broadcast_fragment;

    load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);

    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        tensor_iterator);
    }
    else {
      compute_source_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        source_iterator1,
        source_iterator2,
        tensor_iterator);
    }
  }

private:

  CUTLASS_DEVICE
  void load_broadcast_fragment_(
    BroadcastFragment & broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    ElementVector const * broadcast_ptr,     ///< Broadcast vector
    MatrixCoord const &problem_size,         ///< Problem size needed to guard against out-of-bounds accesses
    MatrixCoord const &threadblock_offset    ///< Threadblock's initial offset within the problem size space
  ) {

    broadcast_fragment.clear();

    // If no pointer is supplied, set with all zeros and avoid memory accesses
    if (!broadcast_ptr) {
      return;
    }

    int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();

    int thread_column_idx = threadblock_offset.column() + thread_initial_column;
    broadcast_ptr += thread_initial_column;

    NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
    using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
    using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;

    ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {

      AccessType loaded;

      loaded.clear();

      if (thread_column_idx < problem_size.column()) {
        loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
      }

      ComputeFragmentType cvt = converter(loaded);
      frag_ptr[j] = cvt;

      thread_column_idx += ThreadMap::Delta::kColumn;
      broadcast_ptr += ThreadMap::Delta::kColumn;
    }
  }

  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                                      WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                              (1 - Base::kFragmentsPerIteration));
      }
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    // CUTLASS_PRAGMA_UNROLL
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_not_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                       Base::kFragmentsPerIteration>>::push(iter,
                                                                            accum_fragment_iterator,
                                                                            this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {

        typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        }
        else if (kPartitionsK > 1) {

          plus<typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
          for (int i = 1; i < kPartitionsK; ++i) {
            shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
            shared_load_iterator_.load(aligned_accum_fragment[i]);
            aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
          }

          shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
        }

        //
        // Apply output operation
        //

        typename OutputTileIterator::Fragment frag_Z;
        typename TensorTileIterator::Fragment frag_T;

        apply_output_operator_source_not_needed_(
          frag_Z,
          frag_T,
          output_op,
          aligned_accum_fragment[0],
          broadcast_fragment);

        //
        // Conditionally store fragments
        //

        if (OutputOp::kStoreZ) {
          destination_iterator.store(frag_Z);
          ++destination_iterator;
        }

        if (OutputOp::kStoreT) {
          tensor_iterator.store(frag_T);
          ++tensor_iterator;
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }
  }

  template<class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    template<int Advance>
    CUTLASS_DEVICE
    static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                       WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator1,          ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,          ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    typename OutputTileIterator::Fragment source_fragment1;
    source_fragment1.clear();
    typename OutputTileIterator::Fragment source_fragment2;
    source_fragment2.clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
      // Load the source
      //

      source_iterator1.load(source_fragment1);
      ++source_iterator1;

      if (source_iterator2.enabled()) {
        source_iterator2.load(source_fragment2);
        ++source_iterator2;
      }

      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
          iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
      if (kPartitionsK > 1)
      {
        plus<typename SharedLoadIterator::Fragment> add_fragments;
        const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_tile_offset({tile_row_offset, 0});
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
      }

      //
      // Apply output operation
      //

      typename OutputTileIterator::Fragment frag_Z;
      typename TensorTileIterator::Fragment frag_T;

      apply_output_operator_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        source_fragment1,
        source_fragment2,
        broadcast_fragment,
        source_iterator2.enabled());

      //
      // Conditionally store fragments
      //

      if (OutputOp::kStoreZ) {
        destination_iterator.store(frag_Z);
        ++destination_iterator;
      }

      if (OutputOp::kStoreT) {
        tensor_iterator.store(frag_T);
        ++tensor_iterator;
      }
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
    typename OutputTileIterator::Fragment &frag_Z,
    typename TensorTileIterator::Fragment &frag_T,
    OutputOp const &output_op,
    typename SharedLoadIterator::Fragment const &frag_AB,
    typename OutputTileIterator::Fragment const &frag_C1,
    typename OutputTileIterator::Fragment const &frag_C2,
    BroadcastFragment const &frag_Broadcast,
    bool frag_C2_enabled) {

    using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
    using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
    using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

    AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
    AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

    AccumulatorAccessType const *frag_AB_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    OutputAccessType const *frag_C1_ptr =
        reinterpret_cast<OutputAccessType const *>(&frag_C1);
    OutputAccessType const *frag_C2_ptr =
        reinterpret_cast<OutputAccessType const *>(&frag_C2);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
        reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      if (frag_C2_enabled) {
        output_op(
          frag_Z_ptr[i],
          frag_T_ptr[i],
          frag_AB_ptr[i],
          frag_C1_ptr[i],
          frag_C2_ptr[i],
          frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
      } else {
        output_op(
          frag_Z_ptr[i],
          frag_T_ptr[i],
          frag_AB_ptr[i],
          frag_C1_ptr[i],
          frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
      }
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
    typename OutputTileIterator::Fragment &frag_Z,
    typename TensorTileIterator::Fragment &frag_T,
    OutputOp const &output_op,
    typename SharedLoadIterator::Fragment const &frag_AB,
    BroadcastFragment const &frag_Broadcast) {

    using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
    using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
    using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

    AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
    AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

    AccumulatorAccessType const *frag_AB_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
        reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      output_op(
        frag_Z_ptr[i],
        frag_T_ptr[i],
        frag_AB_ptr[i],
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
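The acc2smem "push" helpers in the file above dispatch a runtime iteration index to a compile-time Advance template argument using a C++11 pack-expansion trick (CUTLASS ships its own cutlass::index_sequence so the idiom also works under NVRTC). A self-contained host-side sketch of the same idiom, with std::index_sequence standing in:

  #include <cstdio>
  #include <utility>

  template <int Advance>
  void helper() { std::printf("advance %d\n", Advance); }

  template <size_t... Seq>
  void push(size_t pos, std::index_sequence<Seq...>) {
    // One guarded entry per Seq; the comma operator discards helper's result,
    // so only the element whose Seq matches pos actually calls helper<Seq>().
    int dummy[] = {(pos == Seq && (helper<Seq>(), false))...};
    (void)dummy;
  }

  int main() {
    push(2, std::make_index_sequence<4>{});  // prints "advance 2"
  }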
include/cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h (new file, 1023 lines; file diff suppressed because it is too large)

include/cutlass/gemm/device/gemm_universal_with_broadcast.h (new file, 384 lines):
@ -0,0 +1,384 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_base.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!
|
||||
The universal GEMM with a broadcast epilogue.
|
||||
Supports
|
||||
*/
|
||||
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassSimt,
  /// Tag indicating architecture to tune for. This is the minimum SM that
  /// supports the intended feature. The device kernel can be built
  /// targeting any SM larger than this number.
  typename ArchTag_ = arch::Sm70,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::InstructionShape,
  /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
  typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
      ElementC_, ElementAccumulator_, ElementAccumulator_,
      ElementC_, ElementC_, 16 / sizeof(ElementC_)>,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kStages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentB,
  /// Operation performed by GEMM
  typename Operator_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::Operator,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB = ComplexTransform::kNone
>
class GemmUniversalWithBroadcast :
  public GemmUniversalBase<
    typename kernel::DefaultGemmWithBroadcastV2<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_
    >::GemmKernel
  > {

public:

  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  using Base = GemmUniversalBase<
    typename kernel::DefaultGemmWithBroadcastV2<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_
    >::GemmKernel
  >;

  using Arguments = typename Base::Arguments;
  using GemmKernel = typename Base::GemmKernel;
};

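// Usage sketch (illustrative only). The element types, layouts, arch tag, and the
// names gemm_op / args / workspace_ptr / stream below are hypothetical placeholders,
// not definitions made by this header:
//
//   using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast<
//     cutlass::half_t, cutlass::layout::RowMajor,     // A
//     cutlass::half_t, cutlass::layout::ColumnMajor,  // B
//     cutlass::half_t, cutlass::layout::RowMajor,     // C and D
//     float,
//     cutlass::arch::OpClassTensorOp,
//     cutlass::arch::Sm75>;
//
//   Gemm gemm_op;
//   Gemm::Arguments args(...);   // mode, problem size, pointers, strides, epilogue params
//   cutlass::Status status = gemm_op(args, workspace_ptr, stream);
//
// The broadcast epilogue consumes the extra Vector/Tensor operands carried in
// Arguments (see GemmWithFusedEpilogueV2, added later in this change).
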
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output: exchanges problem size and operands.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for. This is the minimum SM that
  /// supports the intended feature. The device kernel can be built
  /// targeting any SM larger than this number.
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Epilogue output operator
  typename EpilogueOutputOp_,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB,
  /// Operation performed by GEMM
  typename Operator_,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB>
class GemmUniversalWithBroadcast<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                                 layout::ColumnMajor,  // partially specialized on LayoutC
                                 ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
                                 WarpShape_, InstructionShape_, EpilogueOutputOp_,
                                 ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
                                 Operator_, TransformA, TransformB> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  using UnderlyingOperator = typename GemmUniversalWithBroadcast<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA,
    Operator,
    kTransformB,
    kTransformA
  >::Base;
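
  // Note on the transpose trick: D = A * B with column-major D is computed as
  // D^T = B^T * A^T with row-major D^T, which is the same memory. Hence the operands,
  // their alignments, and their complex transforms are exchanged and the layouts
  // transposed when forming the underlying row-major operator.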

  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  static int const kAlignmentC = EpilogueOutputOp::kCount;

  /// Argument structure
  using Arguments = typename UnderlyingOperator::Arguments;

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmUniversalWithBroadcast() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static Arguments to_underlying_arguments(Arguments const &args) {
    return args.transposed_problem();
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {
    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes, then runs the kernel using the given arguments.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
242  include/cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h  Normal file
@ -0,0 +1,242 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
      Defines a GEMM with a broadcast epilogue based on an existing UniversalGemm kernel.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"

#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator,
  ///
  typename Enable = void
>
struct DefaultGemmWithBroadcastV2 {

  using GemmBase = typename DefaultGemmUniversal<
    ElementA_, LayoutA_, TransformA, kAlignmentA,
    ElementB_, LayoutB_, TransformB, kAlignmentB,
    ElementC_, LayoutC_, ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    Operator
  >::GemmKernel;

  // Replace epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOpV2<
    typename GemmBase::Epilogue::Shape,
    typename GemmBase::Epilogue::WarpMmaOperator,
    GemmBase::Epilogue::kPartitionsK,
    ElementC_,
    typename EpilogueOutputOp::ElementT,
    ElementC_,
    EpilogueOutputOp,
    GemmBase::Epilogue::kElementsPerAccess
  >::Epilogue;

  // Compose the GEMM kernel
  using GemmKernel = GemmWithFusedEpilogueV2<
    typename GemmBase::Mma,
    Epilogue,
    ThreadblockSwizzle
  >;
};
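
// Composition pattern used above (and in the Sm70 specialization below): instantiate the
// stock DefaultGemmUniversal kernel to reuse its mainloop ('Mma') and tile geometry, swap
// in an epilogue that supports the broadcast operands, then recompose both with
// GemmWithFusedEpilogueV2. Only the epilogue differs from a plain universal GEMM.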


/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization: ArchTag = cutlass::arch::Sm70, which selects the Volta
/// tensor op broadcast epilogue.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator,
  ///
  typename Enable
>
struct DefaultGemmWithBroadcastV2<
  ElementA_, LayoutA_, TransformA, kAlignmentA,
  ElementB_, LayoutB_, TransformB, kAlignmentB,
  ElementC_, LayoutC_,
  ElementAccumulator,
  OperatorClass,
  cutlass::arch::Sm70,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  Operator,
  Enable
> {

  using GemmBase = typename DefaultGemmUniversal<
    ElementA_, LayoutA_, TransformA, kAlignmentA,
    ElementB_, LayoutB_, TransformB, kAlignmentB,
    ElementC_, LayoutC_, ElementAccumulator,
    OperatorClass,
    cutlass::arch::Sm70,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    Operator
  >::GemmKernel;

  // Replace epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOpV2<
    typename GemmBase::Epilogue::Shape,
    typename GemmBase::Epilogue::WarpMmaOperator,
    GemmBase::Epilogue::kPartitionsK,
    ElementC_,
    typename EpilogueOutputOp::ElementT,
    ElementC_,
    EpilogueOutputOp,
    GemmBase::Epilogue::kElementsPerAccess
  >::Epilogue;

  // Compose the GEMM kernel
  using GemmKernel = GemmWithFusedEpilogueV2<
    typename GemmBase::Mma,
    Epilogue,
    ThreadblockSwizzle
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
816  include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h  Normal file
@ -0,0 +1,816 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Gemm kernel with fused reduction operation.
*/

#pragma once

#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
  typename Mma_,                 ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,            ///! Epilogue
  typename ThreadblockSwizzle_   ///! Threadblock swizzling function
>
struct GemmWithFusedEpilogueV2 {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename Epilogue::OutputTileIterator::Layout;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;
  using Operator = typename Mma::Operator;

  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment = const_max(
    128 / sizeof_bits<ElementA>::value,
    128 / sizeof_bits<ElementB>::value
  );
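
  // Worked example (illustrative): for half_t operands (16 bits each),
  // kSplitKAlignment = max(128 / 16, 128 / 16) = 8 elements, so each split-K slice
  // of the K dimension is sized in multiples of 8; for int8 operands it would be 16.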

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode;
    GemmCoord problem_size;
    int batch_count;

    typename EpilogueOutputOp::Params epilogue;

    void const * ptr_A;
    void const * ptr_B;
    void const * ptr_C1;
    void const * ptr_C2;
    void * ptr_D;

    void * ptr_Vector;
    void * ptr_Tensor;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C1;
    int64_t batch_stride_C2;
    int64_t batch_stride_D;
    int64_t batch_stride_Vector;
    int64_t batch_stride_Tensor;

    typename LayoutA::Stride::Index lda;
    typename LayoutB::Stride::Index ldb;
    typename LayoutC::Stride::Index ldc1;
    typename LayoutC::Stride::Index ldc2;
    typename LayoutC::Stride::Index ldd;
    typename LayoutC::Stride::Index ldr;
    typename LayoutC::Stride::Index ldt;

    //
    // Methods
    //

    Arguments():
      mode(GemmUniversalMode::kGemm),
      batch_count(1),
      ptr_A(nullptr), ptr_B(nullptr), ptr_C1(nullptr), ptr_C2(nullptr), ptr_D(nullptr),
      ptr_Vector(nullptr), ptr_Tensor(nullptr) { }

    /// Constructs an arguments structure
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_count,
      typename EpilogueOutputOp::Params epilogue,
      void const * ptr_A,
      void const * ptr_B,
      void const * ptr_C1,
      void const * ptr_C2,
      void * ptr_D,
      void * ptr_Vector,
      void * ptr_Tensor,
      int64_t batch_stride_A,
      int64_t batch_stride_B,
      int64_t batch_stride_C1,
      int64_t batch_stride_C2,
      int64_t batch_stride_D,
      int64_t batch_stride_Vector,
      int64_t batch_stride_Tensor,
      typename LayoutA::Stride::Index lda,
      typename LayoutB::Stride::Index ldb,
      typename LayoutC::Stride::Index ldc1,
      typename LayoutC::Stride::Index ldc2,
      typename LayoutC::Stride::Index ldd,
      typename LayoutC::Stride::Index ldr,
      typename LayoutC::Stride::Index ldt
    ):
      mode(mode),
      problem_size(problem_size),
      batch_count(batch_count),
      epilogue(epilogue),
      ptr_A(ptr_A), ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D),
      ptr_Vector(ptr_Vector),
      ptr_Tensor(ptr_Tensor),
      batch_stride_A(batch_stride_A),
      batch_stride_B(batch_stride_B),
      batch_stride_C1(batch_stride_C1),
      batch_stride_C2(batch_stride_C2),
      batch_stride_D(batch_stride_D),
      batch_stride_Vector(batch_stride_Vector),
      batch_stride_Tensor(batch_stride_Tensor),
      lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt)
    {
      CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size);
      CUTLASS_TRACE_HOST("  ptr_Vector: " << (void *)this->ptr_Vector);
      CUTLASS_TRACE_HOST("  ptr_Tensor: " << (void *)this->ptr_Tensor);
      CUTLASS_TRACE_HOST("  ldr: " << this->ldr);
      CUTLASS_TRACE_HOST("  ldt: " << this->ldt);
    }

    /// Constructs an arguments structure with a single source matrix (C2 disabled)
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_count,
      typename EpilogueOutputOp::Params epilogue,
      void const * ptr_A,
      void const * ptr_B,
      void const * ptr_C,
      void * ptr_D,
      void * ptr_Vector,
      void * ptr_Tensor,
      int64_t batch_stride_A,
      int64_t batch_stride_B,
      int64_t batch_stride_C,
      int64_t batch_stride_D,
      int64_t batch_stride_Vector,
      int64_t batch_stride_Tensor,
      typename LayoutA::Stride::Index lda,
      typename LayoutB::Stride::Index ldb,
      typename LayoutC::Stride::Index ldc,
      typename LayoutC::Stride::Index ldd,
      typename LayoutC::Stride::Index ldr,
      typename LayoutC::Stride::Index ldt
    ): Arguments(
      mode, problem_size, batch_count, epilogue,
      ptr_A, ptr_B, ptr_C, nullptr, ptr_D, ptr_Vector, ptr_Tensor,
      batch_stride_A, batch_stride_B, batch_stride_C, 0, batch_stride_D,
      batch_stride_Vector, batch_stride_Tensor,
      lda, ldb, ldc, 0, ldd, ldr, ldt) {}

    /// Returns arguments for the transposed problem
    Arguments transposed_problem() const {
      Arguments args(*this);

      std::swap(args.problem_size.m(), args.problem_size.n());
      std::swap(args.ptr_A, args.ptr_B);
      std::swap(args.lda, args.ldb);
      std::swap(args.batch_stride_A, args.batch_stride_B);

      return args;
    }
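
    // Note: only the A/B operands, strides, and problem dimensions are exchanged here.
    // The C/D strides are reused as-is because an M-by-N column-major matrix with
    // leading dimension ld occupies the same memory as an N-by-M row-major matrix with
    // the same ld, which is how the column-major device-level specialization consumes
    // these transposed arguments.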
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {

    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int swizzle_log_tile;

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename Epilogue::OutputTileIterator::Params params_C1;
    typename Epilogue::OutputTileIterator::Params params_C2;
    typename Epilogue::OutputTileIterator::Params params_D;
    typename Epilogue::TensorTileIterator::Params params_Tensor;

    typename EpilogueOutputOp::Params output_op;

    GemmUniversalMode mode;
    int batch_count;
    int gemm_k_size;

    void * ptr_A;
    void * ptr_B;
    void * ptr_C1;
    void * ptr_C2;
    void * ptr_D;

    void * ptr_Vector;
    typename LayoutC::Stride::Index ldr;

    void * ptr_Tensor;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C1;
    int64_t batch_stride_C2;
    int64_t batch_stride_D;
    int64_t batch_stride_Vector;
    int64_t batch_stride_Tensor;

    int *semaphore;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      swizzle_log_tile(0),
      params_A(0),
      params_B(0),
      params_C1(0),
      params_C2(0),
      params_D(0),
      mode(cutlass::gemm::GemmUniversalMode::kGemm),
      batch_count(0),
      gemm_k_size(0),
      ptr_A(nullptr),
      ptr_B(nullptr),
      ptr_C1(nullptr),
      ptr_C2(nullptr),
      ptr_D(nullptr),
      ptr_Vector(nullptr),
      ldr(0),
      ptr_Tensor(nullptr),
      batch_stride_A(0),
      batch_stride_B(0),
      batch_stride_C1(0),
      batch_stride_C2(0),
      batch_stride_D(0),
      batch_stride_Vector(0),
      batch_stride_Tensor(0),
      semaphore(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      int gemm_k_size,
      void *workspace = nullptr
    ):
      problem_size(args.problem_size),
      grid_tiled_shape(grid_tiled_shape),
      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
      params_A(args.lda),
      params_B(args.ldb),
      params_C1(args.ldc1),
      params_C2(args.ldc2),
      params_D(args.ldd),
      params_Tensor(args.ldt),
      output_op(args.epilogue),
      mode(args.mode),
      batch_count(args.batch_count),
      gemm_k_size(gemm_k_size),
      ptr_A(const_cast<void *>(args.ptr_A)),
      ptr_B(const_cast<void *>(args.ptr_B)),
      ptr_C1(const_cast<void *>(args.ptr_C1)),
      ptr_C2(const_cast<void *>(args.ptr_C2)),
      ptr_D(args.ptr_D),
      ptr_Vector(args.ptr_Vector),
      ldr(args.ldr),
      ptr_Tensor(args.ptr_Tensor),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      batch_stride_C1(args.batch_stride_C1),
      batch_stride_C2(args.batch_stride_C2),
      batch_stride_D(args.batch_stride_D),
      batch_stride_Vector(args.batch_stride_Vector),
      batch_stride_Tensor(args.batch_stride_Tensor),
      semaphore(static_cast<int *>(workspace)) {

      CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size);
      CUTLASS_TRACE_HOST("  ptr_Vector: " << (void *)this->ptr_Vector);
      CUTLASS_TRACE_HOST("  ptr_Tensor: " << (void *)this->ptr_Tensor);
      CUTLASS_TRACE_HOST("  ldr: " << this->ldr);
      CUTLASS_TRACE_HOST("  ldt: " << args.ldt);
    }

    CUTLASS_HOST_DEVICE
    void update(
      Arguments const &args,
      void *workspace = nullptr) {

      ptr_A = const_cast<void *>(args.ptr_A);
      ptr_B = const_cast<void *>(args.ptr_B);
      ptr_C1 = const_cast<void *>(args.ptr_C1);
      ptr_C2 = const_cast<void *>(args.ptr_C2);
      ptr_D = args.ptr_D;

      ptr_Vector = args.ptr_Vector;
      ldr = args.ldr;
      ptr_Tensor = args.ptr_Tensor;

      batch_stride_A = args.batch_stride_A;
      batch_stride_B = args.batch_stride_B;
      batch_stride_C1 = args.batch_stride_C1;
      batch_stride_C2 = args.batch_stride_C2;
      batch_stride_D = args.batch_stride_D;
      batch_stride_Vector = args.batch_stride_Vector;
      batch_stride_Tensor = args.batch_stride_Tensor;

      output_op = args.epilogue;

      semaphore = static_cast<int *>(workspace);

      CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
      CUTLASS_TRACE_HOST("  ptr_Vector: " << (void *)this->ptr_Vector);
      CUTLASS_TRACE_HOST("  ptr_Tensor: " << (void *)this->ptr_Tensor);
      CUTLASS_TRACE_HOST("  ldr: " << this->ldr);
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };
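
  // Design note: a union (rather than a struct) lets the mainloop and the epilogue share
  // one shared-memory allocation, since the two phases never hold live data at the same
  // time. This roughly halves the kernel's shared memory footprint.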

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmWithFusedEpilogueV2() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size) {

    CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()");

    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
        (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
        (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {

      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand");
      return Status::kErrorMisalignedOperand;
    }

    CUTLASS_TRACE_HOST("  returning kSuccess");

    return Status::kSuccess;
  }
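
  // Worked example (illustrative): with 128-bit accesses on half_t operands,
  // kAlignmentA = kAlignmentB = kAlignmentC = 8, so a problem size such as
  // M=128, N=256, K=64 passes, while M=130 returns kErrorMisalignedOperand.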

  static Status can_implement(Arguments const &args) {
    return can_implement(args.problem_size);
  }

  static size_t get_extra_workspace_size(Arguments const &args,
                                         cutlass::gemm::GemmCoord const &grid_tiled_shape) {

    return 0;
  }

#define SPLIT_K_ENABLED 1

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

#if SPLIT_K_ENABLED
    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm ||
        params.mode == GemmUniversalMode::kGemmSplitKParallel) {

      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {

        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
      ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
    }
#endif
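
    // In kBatched mode the grid's k dimension indexes batches, so ptr_A/ptr_B advance by
    // a batch stride; in kArray mode params.ptr_A/ptr_B instead point to arrays of
    // per-problem pointers, which is why the casts and indexing differ between branches.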

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      ptr_A,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      ptr_B,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Number of mainloop iterations covering this threadblock's K range (ceiling division)
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
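    // Example (illustrative): with Mma::Shape::kK = 32 and a 100-element K range,
    // (100 + 32 - 1) / 32 = 4 iterations; the trailing partial tile is handled by the
    // operand iterators' predication.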

    // Compute threadblock-scoped matrix multiply-add
    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A,
      iterator_B,
      accumulators);

    //
    // Epilogue
    //

    EpilogueOutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    ElementC *ptr_C1 = static_cast<ElementC *>(params.ptr_C1);
    ElementC *ptr_C2 = static_cast<ElementC *>(params.ptr_C2);
    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
    typename Epilogue::ElementTensor *ptr_Tensor = static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);

    // Define the reduction output pointer and move to the appropriate place
    typename Epilogue::ElementVector *ptr_Vector =
      static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);

    //
    // Fetch pointers based on mode.
    //

    //
    // Special path when split-K not enabled.
    //

    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) {

      // Tile iterators loading from the source tensors.
      typename Epilogue::OutputTileIterator iterator_C1(
        params.params_C1,
        ptr_C1,
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );
      typename Epilogue::OutputTileIterator iterator_C2(
        params.params_C2,
        ptr_C2,
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      // Tile iterator writing to destination tensor.
      typename Epilogue::OutputTileIterator iterator_D(
        params.params_D,
        ptr_D,
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      // Additional tensor to load from
      typename Epilogue::TensorTileIterator tensor_iterator(
        params.params_Tensor,
        ptr_Tensor,
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset);

      // Construct the epilogue
      Epilogue epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx);

      // Move to appropriate location for this output tile
      if (ptr_Vector) {
        ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr;
      }

      // Execute the epilogue operator to update the destination tensor.
      epilogue(output_op,
               ptr_Vector,
               iterator_D,
               accumulators,
               iterator_C1,
               iterator_C2,
               tensor_iterator,
               params.problem_size.mn(),
               threadblock_offset);

      return;
    }

    //
    // Slower path when split-K or batching is needed
    //

#if SPLIT_K_ENABLED
    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    if (params.mode == GemmUniversalMode::kGemm) {

      // If performing a reduction via split-K, fetch the initial synchronization
      if (params.grid_tiled_shape.k() > 1) {

        // Fetch the synchronization lock initially but do not block.
        semaphore.fetch();

        // Indicate which position in a serial reduction the output operator is currently updating
        output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
      }
    }
    else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
      if (ptr_C2) {
        ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2;
      }
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
      if (ptr_Tensor) {
        ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
      }
      if (ptr_Vector) {
        ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector;
      }
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_C1 = static_cast<ElementC * const *>(params.ptr_C1)[threadblock_tile_offset.k()];
      if (ptr_C2) {
        ptr_C2 = static_cast<ElementC * const *>(params.ptr_C2)[threadblock_tile_offset.k()];
      }
      ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
      if (ptr_Tensor) {
        ptr_Tensor = static_cast<typename Epilogue::ElementTensor * const *>(params.ptr_Tensor)[threadblock_tile_offset.k()];
      }
      if (ptr_Vector) {
        ptr_Vector = static_cast<typename Epilogue::ElementVector * const *>(params.ptr_Vector)[threadblock_tile_offset.k()];
      }
    }
#endif

    // Tile iterators loading from the source tensors.
    typename Epilogue::OutputTileIterator iterator_C1(
      params.params_C1,
      ptr_C1,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );
    typename Epilogue::OutputTileIterator iterator_C2(
      params.params_C2,
      ptr_C2,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      ptr_D,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    // Additional tensor to load from
    typename Epilogue::TensorTileIterator tensor_iterator(
      params.params_Tensor,
      // Only the final block outputs Tensor
      ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) &&
       (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1))
        ? nullptr
        : ptr_Tensor,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset);

    // Construct the epilogue
    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

#if SPLIT_K_ENABLED
    // Wait on the semaphore - this latency may have been covered by iterator construction
    if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_offset.k()) {
        iterator_C1 = iterator_D;
      }

      semaphore.wait(threadblock_tile_offset.k());
    }
#endif

    // Move to appropriate location for this output tile
    if (ptr_Vector) {
      ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr;
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(output_op,
             // Only the final block uses Vector
             ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) &&
              (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1))
               ? nullptr
               : ptr_Vector,
             iterator_D,
             accumulators,
             iterator_C1,
             iterator_C2,
             tensor_iterator,
             params.problem_size.mn(),
             threadblock_offset);

    //
    // Release the semaphore
    //

#if SPLIT_K_ENABLED
    if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_offset.k() + 1;
      }

      semaphore.release(lock);
    }
#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////