Gemm broadcast (#632)

* gemm_universal_with_broadcast, +2 sources.

* Revert "gemm_universal_with_broadcast, +2 sources."

This reverts commit fb063251f2.

* gemm_universal_with_broadcast separated version.

* Update copyright banner.

* update banner
Author: Ying Zhang
Date: 2022-09-20 07:37:12 -07:00 (committed by GitHub)
Parent: f73374a1eb
Commit: a821280dc7
8 changed files with 3697 additions and 9 deletions


@@ -53,6 +53,8 @@ namespace thread {
template <typename T>
struct Identity {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T value) const {
return value;
@@ -160,7 +162,7 @@ struct LeakyReLU {
Params():
LinearCombinationGenericParams<T>(),
leaky_alpha(T(1)) {}
CUTLASS_HOST_DEVICE
Params(
T alpha,
@@ -194,7 +196,7 @@ struct LeakyReLU<Array<T, N> > {
Params():
LinearCombinationGenericParams<T>(),
leaky_alpha(T(1)) {}
CUTLASS_HOST_DEVICE
Params(
T alpha,
@@ -464,7 +466,7 @@ struct HardSwish<Array<half_t, N> > {
maximum<Array<T, N> > mx;
multiplies<Array<T, N> > mul;
plus<Array<T, N> > add;
return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
}
@@ -493,7 +495,7 @@ struct GELU {
return T(cutlass::constants::half<T>() * scalar *
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
@@ -509,7 +511,7 @@ struct GELU<float> {
return cutlass::constants::half<float>() * scalar *
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_HOST_DEVICE
@@ -525,7 +527,7 @@ struct GELU<double> {
return cutlass::constants::half<double>() * scalar *
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
}
using Params = LinearCombinationGenericParams<double>;
CUTLASS_HOST_DEVICE
@@ -548,7 +550,7 @@ struct GELU<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
@@ -572,7 +574,7 @@ struct GELU_taylor {
}
using Params = LinearCombinationGenericParams<T>;
};
template <int N>
@@ -618,7 +620,7 @@ struct GELU_taylor<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
};
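The kIsHeavy flag added to Identity above feeds CUTLASS's IsEpilogueFunctorHeavy trait, which the new epilogue further down consults when deciding whether to unroll its iteration loop (see the IterationsUnroll parameter). As an illustrative sketch only (not part of this diff), a custom activation functor following the same concept might look like:

// Hypothetical functor following the concept of Identity above; illustrative only.
template <typename T>
struct DoubleIt {
  // Cheap elementwise op: leaving kIsHeavy false keeps epilogue unrolling enabled.
  static const bool kIsHeavy = false;

  CUTLASS_HOST_DEVICE
  T operator()(T value) const {
    return value + value;
  }
};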


@@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue functor specialized for residual blocks in deep neural network.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual)),
/// or of the form: UnaryOp(BinaryOp2(BinaryOp1(ActivationOp(TensorOp(X) + bias), residual1), residual2))
template <typename ElementOutput_, typename ElementAccumulator_,
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_=BinaryOp1_>
class LinearCombinationResidualBlockV2 {
public:
using ElementOutput = ElementC_;
using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementC, kElementsPerAccess>;
using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
using ElementZ = ElementOutput_;
using ElementT = ElementZ;
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
using FragmentT = Array<ElementT, kElementsPerAccess>;
static bool const kIsHeavy = true;
static bool const kStoreZ = true;
static bool const kStoreT = false;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales residual input
ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory
CUTLASS_HOST_DEVICE
Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha, ElementCompute beta)
: alpha(alpha), beta(beta) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
};
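// Illustrative note (not part of this diff): Params may be constructed either from
// host-side scalars, e.g. Params(ElementCompute(0.5f), ElementCompute(1.0f)), or from
// device pointers, which are dereferenced in the functor constructor below.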
private:
ElementCompute alpha_;
ElementCompute beta_;
bool skip_elementwise_;
public:
/// Constructor from Params
CUTLASS_HOST_DEVICE
LinearCombinationResidualBlockV2(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
skip_elementwise_ = false;
}
/// The "source" tensor corresponds to the residual input
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return true; }
/// Functionally required for serial reduction in the epilogue
/// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
skip_elementwise_ = true;
}
}
/// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
FragmentC const &residual,
FragmentCompute const &bias) const {
UnaryOp unary_op;
BinaryOp1 binary_op;
ActivationOp activation;
FragmentCompute tmp_Accum =
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
FragmentCompute tmp_residual =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);
FragmentCompute z =
binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
}
/// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
FragmentC const &residual1, FragmentC const &residual2,
FragmentCompute const &bias) const {
UnaryOp unary_op;
BinaryOp1 binary_op1;
BinaryOp2 binary_op2;
ActivationOp activation;
FragmentCompute tmp_Accum =
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
FragmentCompute tmp_residual1 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
FragmentCompute tmp_residual2 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);
FragmentCompute z =
binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
}
/// Should never be called
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
FragmentCompute const &) const {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
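A hedged instantiation sketch: ReLu, Identity, and plus are standard CUTLASS functors, but the alias below is hypothetical and not part of this diff. An epilogue computing plus(ReLu(alpha * AB + bias), beta * residual) on half-precision data with 128-bit accesses could be declared as:

#include "cutlass/epilogue/thread/activation.h"  // ReLu, Identity
#include "cutlass/functional.h"                  // plus

using ResidualEpilogue = cutlass::epilogue::thread::LinearCombinationResidualBlockV2<
    cutlass::half_t,  // ElementOutput_
    float,            // ElementAccumulator_
    float,            // ElementCompute_
    cutlass::half_t,  // ElementC_
    8,                // ElementsPerAccess: 128 bits / 16-bit elements
    cutlass::epilogue::thread::ReLu,      // ActivationOp_
    cutlass::plus,                        // BinaryOp1_ (BinaryOp2_ defaults to it)
    cutlass::epilogue::thread::Identity>; // UnaryOp_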


@@ -0,0 +1,177 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for TensorOps.
template <
typename Shape,
typename WarpMmaTensorOp,
int PartitionsK,
typename ElementOutput,
typename ElementTensor,
typename ElementVector,
typename OutputOp,
int ElementsPerAccess,
bool ScatterD = false
>
struct DefaultEpilogueWithBroadcastTensorOpV2 {
/// Use defaults related to the existing epilogue
using Base = DefaultEpilogueTensorOp<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp,
ElementsPerAccess
>;
//
// Stores the result z = (y = GEMM(A, B, C), broadcast)
//
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementOutput,
ScatterD
>;
//
// Additional tensor tile iterator - stores t = Elementwise(z)
//
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementTensor
>;
/// Define the epilogue
using Epilogue = EpilogueWithBroadcastV2<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputTileIterator,
TensorTileIterator,
ElementVector,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
OutputOp,
typename Base::Padding,
Base::kFragmentsPerIteration
>;
};
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for VoltaTensorOps.
template <
typename Shape,
typename WarpMmaTensorOp,
int PartitionsK,
typename ElementOutput,
typename ElementTensor,
typename ElementVector,
typename OutputOp,
int ElementsPerAccess
>
struct DefaultEpilogueWithBroadcastVoltaTensorOpV2 {
/// Use defaults related to the existing epilogue
using Base = DefaultEpilogueVoltaTensorOp<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp,
ElementsPerAccess
>;
//
// Stores the result z = (y = GEMM(A, B, C), broadcast)
//
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementOutput
>;
//
// Additional tensor tile iterator - stores t = Elementwise(z)
//
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementTensor
>;
/// Define the epilogue
using Epilogue = EpilogueWithBroadcastV2<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputTileIterator,
TensorTileIterator,
ElementVector,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
OutputOp,
typename Base::Padding
>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,847 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include <utility>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h"
#include "cutlass/util/index_sequence.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp
template <
typename ElementC_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementZ_,
typename ElementT_,
int ElementsPerAccess,
bool StoreZ = true,
bool StoreT = true
>
struct EpilogueWithBroadcastOpBaseV2 {
using ElementOutput = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementZ = ElementZ_;
using ElementT = ElementT_;
static int const kElementsPerAccess = ElementsPerAccess;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementOutput, kElementsPerAccess>;
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
using FragmentT = Array<ElementT, kElementsPerAccess>;
/// If true, the 'Z' tensor is stored
static bool const kStoreZ = StoreZ;
/// If true, the 'T' tensor is stored
static bool const kStoreT = StoreT;
/// Parameters structure - required
struct Params { };
//
// Methods
//
/// Constructor from Params
EpilogueWithBroadcastOpBaseV2(Params const &params_) { }
/// Determines whether the source tensor is needed.
bool is_source_needed() const {
return true;
}
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) { }
/// Applies the operation when is_source_needed() is true
CUTLASS_HOST_DEVICE
void operator()(
FragmentZ &frag_Z,
FragmentT &frag_T,
FragmentAccumulator const &AB,
FragmentC const &frag_C1,
FragmentC const &frag_C2,
FragmentCompute const &V) const {
}
/// Applies the operation when is_source_needed() is false
CUTLASS_HOST_DEVICE
void operator()(
FragmentZ &frag_Z,
FragmentT &frag_T,
FragmentAccumulator const &AB,
FragmentCompute const &V) const {
}
};
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///   Z, T = OutputOp(AB, C, Broadcast)
///
///   if (OutputOp::kStoreZ) {
///     store(Z);
///   }
///
///   if (OutputOp::kStoreT) {
///     store(T);
///   }
///
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z)
typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t)
typename ElementVector_, ///< Pointer to broadcast vector
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class EpilogueWithBroadcastV2 :
public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using TensorTileIterator = TensorTileIterator_;
using ElementVector = ElementVector_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Compute data type produced by the output op
using ElementCompute = typename OutputOp::ElementCompute;
/// Compute fragment
using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementCompute,
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Data type of additional tensor
using ElementTensor = typename TensorTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
/// Tensor access type
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
/// Shared memory allocation from epilogue base class
using BaseSharedStorage = typename Base::SharedStorage;
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = kWarpSize * WarpCount::kCount;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
/// Number of accesses per thread needed to cover a full row of the output tile
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
/// Shape of the shared memory allocation for the epilogue
using StorageShape = MatrixShape<
kThreadRows,
Shape::kN
>;
/// Debug printing
CUTLASS_DEVICE
static void print() {
#if 0
printf("BroadcastDetail {\n");
printf(
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
"kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
kColumnsPerThread,
kRowsPerThread,
kThreadCount,
kThreadsPerRow,
kThreadRows,
kThreadAccessesPerRow,
StorageShape::kRow,
StorageShape::kColumn,
StorageShape::kCount
);
printf("};\n");
#endif
}
};
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
struct SharedStorage {
union {
BaseSharedStorage base;
};
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
public:
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
/// Thread index within the threadblock
int thread_idx_;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueWithBroadcastV2(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
):
Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.base.reference(), thread_idx),
thread_idx_(thread_idx)
{
}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator1, ///< Tile iterator for the first source matrix
OutputTileIterator source_iterator2, ///< Tile iterator for the second source matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord()) {
BroadcastFragment broadcast_fragment;
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
if (!output_op.is_source_needed()) {
compute_source_not_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
tensor_iterator);
}
else {
compute_source_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
source_iterator1,
source_iterator2,
tensor_iterator);
}
}
private:
CUTLASS_DEVICE
void load_broadcast_fragment_(
BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
ElementVector const * broadcast_ptr, ///< Broadcast vector
MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
) {
broadcast_fragment.clear();
// If no pointer is supplied, set with all zeros and avoid memory accesses
if (!broadcast_ptr) {
return;
}
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
int thread_column_idx = threadblock_offset.column() + thread_initial_column;
broadcast_ptr += thread_initial_column;
NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
AccessType loaded;
loaded.clear();
if (thread_column_idx < problem_size.column()) {
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
}
ComputeFragmentType cvt = converter(loaded);
frag_ptr[j] = cvt;
thread_column_idx += ThreadMap::Delta::kColumn;
broadcast_ptr += ThreadMap::Delta::kColumn;
}
}
template <class Seq>
struct acc2smem_source_not_needed;
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
(1 - Base::kFragmentsPerIteration));
}
}
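// Compile-time dispatch: the brace-initializer below expands one guarded expression
// per element of Seq; only the entry whose index matches pos invokes helper<>(), the
// rest short-circuit on the boolean test.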
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {
(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
// CUTLASS_PRAGMA_UNROLL
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<
cutlass::make_index_sequence<OutputTileIterator::kIterations /
Base::kFragmentsPerIteration>>::push(iter,
accum_fragment_iterator,
this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
}
else if (kPartitionsK > 1) {
plus <typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
apply_output_operator_source_not_needed_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
broadcast_fragment);
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
if (Base::kFragmentsPerIteration > 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
}
template<class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template<int Advance>
CUTLASS_DEVICE
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator1, ///< Tile iterator for the first source matrix
OutputTileIterator source_iterator2, ///< Tile iterator for the second source matrix
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
typename OutputTileIterator::Fragment source_fragment1;
source_fragment1.clear();
typename OutputTileIterator::Fragment source_fragment2;
source_fragment2.clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Load the source
//
source_iterator1.load(source_fragment1);
++source_iterator1;
if (source_iterator2.enabled()) {
source_iterator2.load(source_fragment2);
++source_iterator2;
}
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
if (kPartitionsK > 1)
{
plus <typename SharedLoadIterator::Fragment> add_fragments;
const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
apply_output_operator_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
source_fragment1,
source_fragment2,
broadcast_fragment,
source_iterator2.enabled());
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
typename OutputTileIterator::Fragment &frag_Z,
typename TensorTileIterator::Fragment &frag_T,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &frag_AB,
typename OutputTileIterator::Fragment const &frag_C1,
typename OutputTileIterator::Fragment const &frag_C2,
BroadcastFragment const &frag_Broadcast,
bool frag_C2_enabled) {
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
OutputAccessType const *frag_C1_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C1);
OutputAccessType const *frag_C2_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C2);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
if (frag_C2_enabled) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C1_ptr[i],
frag_C2_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
} else {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C1_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
typename OutputTileIterator::Fragment &frag_Z,
typename TensorTileIterator::Fragment &frag_T,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &frag_AB,
BroadcastFragment const &frag_Broadcast) {
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
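Functionally, for each output element the epilogue above evaluates the output functor on the accumulator, the broadcast vector value, and up to two source fragments. A scalar model of that computation, under the assumption ActivationOp = ReLU, BinaryOp1 = BinaryOp2 = plus, and UnaryOp = Identity (illustrative only, not part of the diff):

#include <algorithm>

// Scalar model of one output element with both sources enabled.
float epilogue_with_broadcast_model(float ab, float c1, float c2, float v,
                                    float alpha, float beta) {
  float activated = std::max(alpha * ab + v, 0.0f); // ActivationOp(alpha * AB + broadcast)
  return (activated + beta * c1) + beta * c2;       // BinaryOp2(BinaryOp1(..., C1), C2)
}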

File diff suppressed because it is too large.


@@ -0,0 +1,384 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Device-level interface for a universal GEMM with a broadcast epilogue and two source tensors.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*!
The universal GEMM with a broadcast epilogue. Supports up to two source (residual) tensors.
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
ElementC_, ElementAccumulator_, ElementAccumulator_,
ElementC_, ElementC_, 16 / sizeof(ElementC_)>,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB = ComplexTransform::kNone
>
class GemmUniversalWithBroadcast :
public GemmUniversalBase<
typename kernel::DefaultGemmWithBroadcastV2<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
> {
public:
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Base = GemmUniversalBase<
typename kernel::DefaultGemmWithBroadcastV2<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
>;
using Arguments = typename Base::Arguments;
using GemmKernel = typename Base::GemmKernel;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for column-major output; exchanges the problem size and operands.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
/// Operation performed by GEMM
typename Operator_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB>
class GemmUniversalWithBroadcast<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
layout::ColumnMajor, // partially specialized on LayoutC
ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
WarpShape_, InstructionShape_, EpilogueOutputOp_,
ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
Operator_, TransformA, TransformB> {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using UnderlyingOperator = typename GemmUniversalWithBroadcast<
ElementB,
typename layout::LayoutTranspose<LayoutB>::type,
ElementA,
typename layout::LayoutTranspose<LayoutA>::type,
ElementC,
layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
kAlignmentB,
kAlignmentA,
Operator,
kTransformB,
kTransformA
>::Base;
using GemmKernel = typename UnderlyingOperator::GemmKernel;
static int const kAlignmentC = EpilogueOutputOp::kCount;
/// Argument structure
using Arguments = typename UnderlyingOperator::Arguments;
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
GemmUniversalWithBroadcast() { }
/// Helper to construct a transposed equivalent for the underlying GEMM operator
static Arguments to_underlying_arguments(Arguments const &args) {
return args.transposed_problem();
}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
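Putting the pieces together, a hedged configuration sketch (assumed shapes, layouts, and architecture; ResidualEpilogue is the hypothetical alias sketched earlier, not part of this diff):

using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast<
    cutlass::half_t, cutlass::layout::RowMajor,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
    cutlass::half_t, cutlass::layout::RowMajor,     // C and D
    float,                                          // accumulator type
    cutlass::arch::OpClassTensorOp,                 // Tensor Core path
    cutlass::arch::Sm75,                            // minimum target architecture
    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile (assumed)
    cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile (assumed)
    cutlass::gemm::GemmShape<16, 8, 8>,             // instruction shape (assumed)
    ResidualEpilogue>;                              // epilogue sketched earlier

The remaining template parameters (swizzle, stage count, alignments, operator) fall back to DefaultGemmConfiguration, as in the primary template above. Once an Arguments object is built for the kernel (its exact fields come from GemmWithFusedEpilogueV2::Arguments, which is in the suppressed diff above), launching follows the usual CUTLASS pattern: construct the operator, then invoke it with the arguments, a workspace pointer, and a stream.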


@@ -0,0 +1,242 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Defines a GEMM with Broadcast based on an existing UniversalGemm kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
///
typename Enable = void
>
struct DefaultGemmWithBroadcastV2 {
using GemmBase = typename DefaultGemmUniversal<
ElementA_, LayoutA_, TransformA, kAlignmentA,
ElementB_, LayoutB_, TransformB, kAlignmentB,
ElementC_, LayoutC_, ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator
>::GemmKernel;
// Replace epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOpV2<
typename GemmBase::Epilogue::Shape,
typename GemmBase::Epilogue::WarpMmaOperator,
GemmBase::Epilogue::kPartitionsK,
ElementC_,
typename EpilogueOutputOp::ElementT,
ElementC_,
EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess
>::Epilogue;
// Compose the GEMM kernel
using GemmKernel = GemmWithFusedEpilogueV2<
typename GemmBase::Mma,
Epilogue,
ThreadblockSwizzle
>;
};
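For orientation, a minimal sketch of instantiating this default follows. Everything in it (the epilogue op, data types, and tile shapes) is an illustrative assumption rather than part of this change:
// Illustrative only: an fp16 tensor-op instantiation targeting Sm80.
// LinearCombinationBiasElementwise is one epilogue op satisfying the
// 'EpilogueWithBroadcastOp' concept required above.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
cutlass::half_t, float, float, cutlass::half_t, cutlass::half_t, 8>;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithBroadcastV2<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
cutlass::gemm::GemmShape<16, 8, 16>,     // instruction tile
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,                                       // mainloop stages
cutlass::arch::OpMultiplyAdd
>::GemmKernel;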
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization: ArchTag = cutlass::arch::Sm70
///
///
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
///
typename Enable
>
struct DefaultGemmWithBroadcastV2<
ElementA_, LayoutA_, TransformA, kAlignmentA,
ElementB_, LayoutB_, TransformB, kAlignmentB,
ElementC_, LayoutC_,
ElementAccumulator,
OperatorClass,
cutlass::arch::Sm70,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator,
Enable
> {
using GemmBase = typename DefaultGemmUniversal<
ElementA_, LayoutA_, TransformA, kAlignmentA,
ElementB_, LayoutB_, TransformB, kAlignmentB,
ElementC_, LayoutC_, ElementAccumulator,
OperatorClass,
cutlass::arch::Sm70,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator
>::GemmKernel;
// Replace epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOpV2<
typename GemmBase::Epilogue::Shape,
typename GemmBase::Epilogue::WarpMmaOperator,
GemmBase::Epilogue::kPartitionsK,
ElementC_,
typename EpilogueOutputOp::ElementT,
ElementC_,
EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess
>::Epilogue;
// Compose the GEMM kernel
using GemmKernel = GemmWithFusedEpilogueV2<
typename GemmBase::Mma,
Epilogue,
ThreadblockSwizzle
>;
};
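The Sm70 specialization above differs from the primary template only in the epilogue trait: Volta's warp-level MMA produces a differently laid-out accumulator fragment than later tensor-op architectures, so the broadcast epilogue is re-derived from DefaultEpilogueWithBroadcastVoltaTensorOpV2 while the mainloop and kernel composition are unchanged.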
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,816 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel with a fused epilogue that consumes a broadcast vector and produces an auxiliary tensor output.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithFusedEpilogueV2 {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value
);
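// (For half_t operands on both sides, for example, this is 128 / 16 = 8 elements.)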
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
void const * ptr_A;
void const * ptr_B;
void const * ptr_C1;
void const * ptr_C2;
void * ptr_D;
void * ptr_Vector;
void * ptr_Tensor;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C1;
int64_t batch_stride_C2;
int64_t batch_stride_D;
int64_t batch_stride_Vector;
int64_t batch_stride_Tensor;
typename LayoutA::Stride::Index lda;
typename LayoutB::Stride::Index ldb;
typename LayoutC::Stride::Index ldc1;
typename LayoutC::Stride::Index ldc2;
typename LayoutC::Stride::Index ldd;
typename LayoutC::Stride::Index ldr;
typename LayoutC::Stride::Index ldt;
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A(nullptr), ptr_B(nullptr), ptr_C1(nullptr), ptr_C2(nullptr), ptr_D(nullptr),
ptr_Vector(nullptr), ptr_Tensor(nullptr) { }
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C1,
void const * ptr_C2,
void * ptr_D,
void * ptr_Vector,
void * ptr_Tensor,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C1,
int64_t batch_stride_C2,
int64_t batch_stride_D,
int64_t batch_stride_Vector,
int64_t batch_stride_Tensor,
typename LayoutA::Stride::Index lda,
typename LayoutB::Stride::Index ldb,
typename LayoutC::Stride::Index ldc1,
typename LayoutC::Stride::Index ldc2,
typename LayoutC::Stride::Index ldd,
typename LayoutC::Stride::Index ldr,
typename LayoutC::Stride::Index ldt
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D),
ptr_Vector(ptr_Vector),
ptr_Tensor(ptr_Tensor),
batch_stride_A(batch_stride_A),
batch_stride_B(batch_stride_B),
batch_stride_C1(batch_stride_C1),
batch_stride_C2(batch_stride_C2),
batch_stride_D(batch_stride_D),
batch_stride_Vector(batch_stride_Vector),
batch_stride_Tensor(batch_stride_Tensor),
lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt)
{
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size);
CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction);
CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor);
CUTLASS_TRACE_HOST(" ldr: " << this->ldr);
CUTLASS_TRACE_HOST(" ldt: " << this->ldt);
}
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
void * ptr_Vector,
void * ptr_Tensor,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
int64_t batch_stride_Vector,
int64_t batch_stride_Tensor,
typename LayoutA::Stride::Index lda,
typename LayoutB::Stride::Index ldb,
typename LayoutC::Stride::Index ldc,
typename LayoutC::Stride::Index ldd,
typename LayoutC::Stride::Index ldr,
typename LayoutC::Stride::Index ldt
): Arguments(
mode, problem_size, batch_count, epilogue,
ptr_A, ptr_B, ptr_C, nullptr, ptr_D, ptr_Vector, ptr_Tensor,
batch_stride_A, batch_stride_B, batch_stride_C, 0, batch_stride_D,
batch_stride_Vector, batch_stride_Tensor,
lda, ldb, ldc, 0, ldd, ldr, ldt) {}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
std::swap(args.batch_stride_A, args.batch_stride_B);
return args;
}
};
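As a usage sketch, populating these arguments for a plain (non-batched, non-split-K) GEMM might look as follows; GemmKernel is assumed to be a composed kernel such as the one sketched earlier, and every pointer, leading dimension, and scalar below is a placeholder, not something defined in this file:
// Hypothetical host-side setup using the single-source-C overload above.
cutlass::gemm::GemmCoord problem_size(m, n, k);
GemmKernel::Arguments args(
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
/*batch_count=*/1,
{alpha, beta},   // epilogue output-op parameters
ptr_A, ptr_B,
ptr_C,           // the second source C2 is defaulted to nullptr by this overload
ptr_D,
ptr_bias,        // ptr_Vector: the broadcast vector (e.g. a length-n bias)
ptr_aux,         // ptr_Tensor: auxiliary epilogue output (may be nullptr)
/*batch_stride A,B,C,D,Vector,Tensor=*/0, 0, 0, 0, 0, 0,
lda, ldb, ldc, ldd,
/*ldr=*/0,       // row stride of the broadcast vector
/*ldt=*/ldd);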
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C1;
typename Epilogue::OutputTileIterator::Params params_C2;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::TensorTileIterator::Params params_Tensor;
typename EpilogueOutputOp::Params output_op;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A;
void * ptr_B;
void * ptr_C1;
void * ptr_C2;
void * ptr_D;
void * ptr_Vector;
typename LayoutC::Stride::Index ldr;
void * ptr_Tensor;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C1;
int64_t batch_stride_C2;
int64_t batch_stride_D;
int64_t batch_stride_Vector;
int64_t batch_stride_Tensor;
int *semaphore;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C1(0),
params_C2(0),
params_D(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
batch_count(0),
gemm_k_size(0),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C1(nullptr),
ptr_C2(nullptr),
ptr_D(nullptr),
ptr_Vector(nullptr),
ldr(0),
ptr_Tensor(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_C1(0),
batch_stride_C2(0),
batch_stride_D(0),
batch_stride_Vector(0),
batch_stride_Tensor(0),
semaphore(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(args.lda),
params_B(args.ldb),
params_C1(args.ldc1),
params_C2(args.ldc2),
params_D(args.ldd),
params_Tensor(args.ldt),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C1(const_cast<void *>(args.ptr_C1)),
ptr_C2(const_cast<void *>(args.ptr_C2)),
ptr_D(args.ptr_D),
ptr_Vector(args.ptr_Vector),
ldr(args.ldr),
ptr_Tensor(args.ptr_Tensor),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C1(args.batch_stride_C1),
batch_stride_C2(args.batch_stride_C2),
batch_stride_D(args.batch_stride_D),
batch_stride_Vector(args.batch_stride_Vector),
batch_stride_Tensor(args.batch_stride_Tensor),
semaphore(static_cast<int *>(workspace)) {
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size);
CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction);
CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor);
CUTLASS_TRACE_HOST(" ldr: " << this->ldr);
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C1 = const_cast<void *>(args.ptr_C1);
ptr_C2 = const_cast<void *>(args.ptr_C2);
ptr_D = args.ptr_D;
ptr_Vector = args.ptr_Vector;
ldr = args.ldr;
ptr_Tensor = args.ptr_Tensor;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C1 = args.batch_stride_C1;
batch_stride_C2 = args.batch_stride_C2;
batch_stride_D = args.batch_stride_D;
batch_stride_Vector = args.batch_stride_Vector;
batch_stride_Tensor = args.batch_stride_Tensor;
output_op = args.epilogue;
semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction);
CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor);
CUTLASS_TRACE_HOST(" ldr: " << this->ldr);
}
};
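The grid_tiled_shape and gemm_k_size fed into this constructor come from the host side; a sketch of that precomputation follows, modeled loosely on CUTLASS's universal device adapters (the exact rounding policy, and the names args and workspace, are assumptions):
// Hypothetical host code; GemmKernel is the composed kernel type.
GemmKernel::ThreadblockSwizzle swizzle;
cutlass::gemm::GemmCoord grid_shape = swizzle.get_tiled_shape(
args.problem_size,
{GemmKernel::ThreadblockShape::kM,
GemmKernel::ThreadblockShape::kN,
GemmKernel::ThreadblockShape::kK},
args.batch_count);
// Per-split K extent, rounded up to the 128b split-K granularity.
int gemm_k_size =
(args.problem_size.k() + grid_shape.k() - 1) / grid_shape.k();
gemm_k_size =
(gemm_k_size + GemmKernel::kSplitKAlignment - 1) /
GemmKernel::kSplitKAlignment * GemmKernel::kSplitKAlignment;
GemmKernel::Params params(args, grid_shape, gemm_k_size, workspace);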
/// Shared memory storage structure
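/// The mainloop and epilogue phases execute back-to-back within a threadblock,
/// so their shared-memory allocations can safely alias one another.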
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmWithFusedEpilogueV2() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
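// At this point [offset_k, problem_size_k) bounds the K extent this threadblock
// contributes; the iterator extents below are clamped to problem_size_k.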
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute the number of mainloop iterations over this threadblock's K range (ceiling division)
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C1 = static_cast<ElementC *>(params.ptr_C1);
ElementC *ptr_C2 = static_cast<ElementC *>(params.ptr_C2);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
typename Epilogue::ElementTensor *ptr_Tensor = static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);
// Define the reduction output pointer and move to the appropriate place
typename Epilogue::ElementVector *ptr_Vector =
static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);
//
// Fetch pointers based on mode.
//
//
// Special path when split-K not enabled.
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) {
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
ptr_C1,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_C2(
params.params_C2,
ptr_C2,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Additional tensor to load from
typename Epilogue::TensorTileIterator tensor_iterator(
params.params_Tensor,
// Without split-K, every threadblock writes its own Tensor tile directly
ptr_Tensor,
params.problem_size.mn(),
thread_idx,
threadblock_offset);
// Construct the epilogue
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Move to appropriate location for this output tile
if (ptr_Vector) {
ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr;
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op,
ptr_Vector,
iterator_D,
accumulators,
iterator_C1,
iterator_C2,
tensor_iterator,
params.problem_size.mn(),
threadblock_offset);
return;
}
//
// Slower path when split-K or batching is needed
//
#if SPLIT_K_ENABLED
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
}
else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
if (ptr_C2) {
ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2;
}
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
if (ptr_Tensor) {
ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
}
if (ptr_Vector) {
ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector;
}
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_C1 = static_cast<ElementC * const *>(params.ptr_C1)[threadblock_tile_offset.k()];
if (ptr_C2) {
ptr_C2 = static_cast<ElementC * const *>(params.ptr_C2)[threadblock_tile_offset.k()];
}
ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
if (ptr_Tensor) {
ptr_Tensor = static_cast<typename Epilogue::ElementTensor * const *>(params.ptr_Tensor)[threadblock_tile_offset.k()];
}
if (ptr_Vector) {
ptr_Vector = static_cast<typename Epilogue::ElementVector * const *>(params.ptr_Vector)[threadblock_tile_offset.k()];
}
}
#endif
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
ptr_C1,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_C2(
params.params_C2,
ptr_C2,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Additional tensor to load from
typename Epilogue::TensorTileIterator tensor_iterator(
params.params_Tensor,
// Only the final block outputs Tensor
((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) &&
(params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1))
? nullptr
: ptr_Tensor,
params.problem_size.mn(),
thread_idx,
threadblock_offset);
// Construct the epilogue
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
#if SPLIT_K_ENABLED
// Wait on the semaphore - this latency may have been covered by iterator construction
if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C1 = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
#endif
// Move to appropriate location for this output tile
if (ptr_Vector) {
ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr;
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op,
// Only the final block uses Vector
((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) &&
(params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1))
? nullptr
: ptr_Vector,
iterator_D,
accumulators,
iterator_C1,
iterator_C2,
tensor_iterator,
params.problem_size.mn(),
threadblock_offset);
//
// Release the semaphore
//
#if SPLIT_K_ENABLED
if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
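For completeness, a bare-bones sketch of launching this kernel directly; the real path goes through a device-level wrapper such as the one earlier in this change, and GemmKernel, params, and stream are assumed from the sketches above:
// Hypothetical direct launch of the composed kernel.
GemmKernel::ThreadblockSwizzle swizzle;
dim3 grid = swizzle.get_grid_shape(params.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(GemmKernel::SharedStorage));
// Opt in to large dynamic shared memory if the tile footprint demands it.
if (smem_size >= (48 << 10)) {
cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);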