CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0
* Update CHANGELOG.md
---------
Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@ -181,7 +181,7 @@ public:
|
||||
device_.reset();
|
||||
host_.clear();
|
||||
|
||||
count = count / kElementsPerStoredVec * kNumStoragePerStoredVec;
|
||||
count = (count + kElementsPerStoredVec - 1) / kElementsPerStoredVec * kNumStoragePerStoredVec;
|
||||
host_.resize(count);
|
||||
|
||||
// Allocate memory
|
||||
|
||||
@ -45,6 +45,7 @@ namespace cutlass {
|
||||
// Strides without batch mode
|
||||
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Int<1>>
|
||||
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<IntT>,
|
||||
@ -55,6 +56,7 @@ make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,
|
||||
}
|
||||
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, IntT>
|
||||
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<IntT>,
|
||||
@ -69,6 +71,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,
|
||||
// Strides with batch mode
|
||||
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Int<1>, int64_t>
|
||||
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<IntT>,
|
||||
@ -86,6 +89,7 @@ make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape
|
||||
}
|
||||
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, IntT, int64_t>
|
||||
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<IntT>,
|
||||
|
||||
@ -257,16 +257,19 @@ void gett_epilogue(
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
std::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
std::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
|
||||
std::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
std::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsReLUAuxNeeded =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> and
|
||||
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
|
||||
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
|
||||
constexpr bool IsClamp =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
|
||||
|
||||
constexpr bool IsBackpropFusion =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
|
||||
@ -276,7 +279,7 @@ void gett_epilogue(
|
||||
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
|
||||
NumericConverter<ElementCompute, ElementC> source_converter;
|
||||
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
||||
NumericConverter<ElementCompute, ElementAux> aux_source_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
|
||||
|
||||
// Scale related converter
|
||||
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
||||
@ -369,7 +372,12 @@ void gett_epilogue(
|
||||
}
|
||||
}
|
||||
|
||||
output = activation(output);
|
||||
if constexpr (IsClamp) { // Treat Clamp as ReLU
|
||||
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
|
||||
}
|
||||
else {
|
||||
output = activation(output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
@ -436,14 +444,14 @@ void Gemm3x(
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
|
||||
|
||||
if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) {
|
||||
Layout layout_A = make_layout_rank3(mainloop_params.A);
|
||||
Layout layout_B = make_layout_rank3(mainloop_params.B);
|
||||
Layout layout_C = make_layout_rank3(epilogue_params.C);
|
||||
Layout layout_D = make_layout_rank3(epilogue_params.D);
|
||||
Layout layout_Aux = make_layout_rank3(epilogue_params.Aux);
|
||||
Layout layout_Bias = make_layout_rank3(epilogue_params.Bias);
|
||||
Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha);
|
||||
Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta);
|
||||
cute::Layout layout_A = make_layout_rank3(mainloop_params.A);
|
||||
cute::Layout layout_B = make_layout_rank3(mainloop_params.B);
|
||||
cute::Layout layout_C = make_layout_rank3(epilogue_params.C);
|
||||
cute::Layout layout_D = make_layout_rank3(epilogue_params.D);
|
||||
cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux);
|
||||
cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias);
|
||||
cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha);
|
||||
cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta);
|
||||
|
||||
auto TensorA = make_tensor(mainloop_params.A.data(), layout_A);
|
||||
auto TensorB = make_tensor(mainloop_params.B.data(), layout_B);
|
||||
|
||||
Reference in New Issue
Block a user