CUTLASS 3.4.0 (#1286)

* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
Pradeep Ramani
2023-12-29 12:21:31 -08:00
committed by GitHub
parent b7508e3379
commit 8236f30675
211 changed files with 11409 additions and 2763 deletions

View File

@ -181,7 +181,7 @@ public:
device_.reset();
host_.clear();
count = count / kElementsPerStoredVec * kNumStoragePerStoredVec;
count = (count + kElementsPerStoredVec - 1) / kElementsPerStoredVec * kNumStoragePerStoredVec;
host_.resize(count);
// Allocate memory

View File

@ -45,6 +45,7 @@ namespace cutlass {
// Strides without batch mode
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Int<1>>
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<IntT>,
@ -55,6 +56,7 @@ make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,
}
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<IntT>,
@ -69,6 +71,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,
// Strides with batch mode
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Int<1>, int64_t>
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<IntT>,
@ -86,6 +89,7 @@ make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape
}
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT, int64_t>
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<IntT>,

View File

@ -257,16 +257,19 @@ void gett_epilogue(
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool IsScalingAndAmaxOutputNeeded =
std::is_same_v<ElementD, cutlass::float_e4m3_t> or
std::is_same_v<ElementD, cutlass::float_e5m2_t>;
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
std::is_same_v<ElementAux, cutlass::float_e4m3_t> or
std::is_same_v<ElementAux, cutlass::float_e5m2_t>;
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
constexpr bool IsReLUAuxNeeded =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> and
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
constexpr bool IsClamp =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
constexpr bool IsBackpropFusion =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
@ -276,7 +279,7 @@ void gett_epilogue(
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
NumericConverter<ElementCompute, ElementC> source_converter;
NumericConverter<ElementCompute, ElementBias> bias_converter;
NumericConverter<ElementCompute, ElementAux> aux_source_converter;
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
// Scale related converter
NumericConverter<ElementCompute, ElementScalar> scale_converter;
@ -369,7 +372,12 @@ void gett_epilogue(
}
}
output = activation(output);
if constexpr (IsClamp) { // Treat Clamp as ReLU
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
}
else {
output = activation(output);
}
}
if constexpr (IsScalingAndAmaxOutputNeeded) {
@ -436,14 +444,14 @@ void Gemm3x(
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) {
Layout layout_A = make_layout_rank3(mainloop_params.A);
Layout layout_B = make_layout_rank3(mainloop_params.B);
Layout layout_C = make_layout_rank3(epilogue_params.C);
Layout layout_D = make_layout_rank3(epilogue_params.D);
Layout layout_Aux = make_layout_rank3(epilogue_params.Aux);
Layout layout_Bias = make_layout_rank3(epilogue_params.Bias);
Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha);
Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta);
cute::Layout layout_A = make_layout_rank3(mainloop_params.A);
cute::Layout layout_B = make_layout_rank3(mainloop_params.B);
cute::Layout layout_C = make_layout_rank3(epilogue_params.C);
cute::Layout layout_D = make_layout_rank3(epilogue_params.D);
cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux);
cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias);
cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha);
cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta);
auto TensorA = make_tensor(mainloop_params.A.data(), layout_A);
auto TensorB = make_tensor(mainloop_params.B.data(), layout_B);