CUTLASS 3.5.1 (#1623)

* CUTLASS 3.5.1

* updates, optimizations, fixes
This commit is contained in:
Vijay Thakkar
2024-07-29 08:46:24 -04:00
committed by GitHub
parent 56b46e2d13
commit be60a0b272
312 changed files with 19793 additions and 6775 deletions

View File

@ -104,23 +104,15 @@ public:
/// Constant reference to element in tensor
using ConstReference = typename ConstTensorRef::Reference;
/// Note: Below is used to handle packing of subbyte elements
/// kBitsStoredVec : The number of bits in the store vector, which is divisible by the element's bit width
/// kElementsPerStoredVec : The number of elements that can be stored per store vector
/// kNumStoragePerStoredVec : How much storage (i.e. sizeof(element storage)) the store vector needs to consume.
/// Usually the element storage of subbyte is uint8_t.
/// Example
/// int2: kBitsStoredVec = 8; kElementsPerStoredVec = 4; kNumStoragePerStoredVec = 1 uint8_t;
/// int4: kBitsStoredVec = 8; kElementsPerStoredVec = 2; kNumStoragePerStoredVec = 1 uint8_t;
static constexpr int kBitsStoredVec = (sizeof_bits<Element>::value < 8) ? cutlass::lcm(sizeof_bits<Element>::value, 8) : sizeof_bits<Element>::value;
static constexpr int kElementsPerStoredVec = kBitsStoredVec / sizeof_bits<Element>::value;
static constexpr int kNumStoragePerStoredVec = kBitsStoredVec / (sizeof(Element) * 8);
static_assert(kBitsStoredVec != 0, "kBitsStoredVec can not be zero");
static_assert(kElementsPerStoredVec != 0, "kElementsPerStoredVec can not be zero");
static_assert(kNumStoragePerStoredVec != 0, "kNumStoragePerStoredVec can not be zero");
private:
private:
using StorageUnit = typename platform::conditional_t<std::is_same_v<Element, bool>, uint8_t, // Avoid the std::vector<bool> specialization
typename platform::conditional_t<sizeof_bits<Element>::value % 8 == 0, // Handle subbyte types
Element, uint8_t>>;
using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits;
static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements;
static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes;
static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit;
//
// Data members
@ -133,13 +125,17 @@ public:
Layout layout_;
/// Host-side memory allocation
/// avoid the std::vector<bool> specialization
std::vector<std::conditional_t<std::is_same_v<Element,bool>, uint8_t, Element>> host_;
std::vector<StorageUnit> host_;
/// Device-side memory
device_memory::allocation<Element> device_;
device_memory::allocation<StorageUnit> device_;
public:
/// Converts a logical element count into the number of StorageUnit objects
/// required to hold it, rounding up to a whole number of containers.
size_t count_to_container_storage_unit_count(size_t count) {
  // Round up to full containers first, then scale by the storage units each container occupies.
  size_t num_containers = (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements;
  return num_containers * kContainerTypeNumStorageUnit;
}
public:
//
// Device and Host Methods
//
@ -185,15 +181,15 @@ public:
device_.reset();
host_.clear();
count = (count + kElementsPerStoredVec - 1) / kElementsPerStoredVec * kNumStoragePerStoredVec;
host_.resize(count);
size_t count_container = count_to_container_storage_unit_count(count);
host_.resize(count_container);
// Allocate memory
Element* device_memory = nullptr;
StorageUnit* device_memory = nullptr;
if (device_backed_) {
device_memory = device_memory::allocate<Element>(count);
device_memory = device_memory::allocate<StorageUnit>(count_container);
}
device_.reset(device_memory, device_backed_ ? count : 0);
device_.reset(device_memory, device_backed_ ? count_container : 0);
}
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
@ -229,8 +225,9 @@ public:
layout_ = layout;
LongIndex new_size = size_t(layout_.capacity(extent_));
LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_)));
if (static_cast<decltype(host_.size())>(new_size) > host_.size()) {
if (static_cast<decltype(host_.size())>(new_size_container) > host_.size()) {
reserve(new_size, device_backed_);
}
}
@ -244,14 +241,14 @@ public:
resize(extent, Layout::packed(extent), device_backed_);
}
/// Returns the number of elements stored in the host tensor
/// Returns the logical number of elements stored in the host tensor
size_t size() const {
return host_.size() / kNumStoragePerStoredVec * kElementsPerStoredVec;
return layout_.capacity(extent_);
}
/// Returns the logical capacity based on extent and layout. May differ from size().
/// Returns the logical capacity in terms of number of elements. May be larger than the size().
LongIndex capacity() const {
return layout_.capacity(extent_);
return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements;
}
/// Gets pointer to host data
@ -277,10 +274,10 @@ public:
}
/// Gets pointer to device data
Element * device_data() { return device_.get(); }
Element * device_data() { return reinterpret_cast<Element *>(device_.get()); }
/// Gets pointer to device data
Element const * device_data() const { return device_.get(); }
Element const * device_data() const { return reinterpret_cast<Element const *>(device_.get()); }
/// Gets pointer to device data with a pointer offset
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
@ -389,7 +386,7 @@ public:
void sync_host() {
if (device_backed()) {
device_memory::copy_to_host(
host_data(), device_data(), size());
host_.data(), device_.get(), device_.size());
}
}
@ -397,7 +394,7 @@ public:
void sync_device() {
if (device_backed()) {
device_memory::copy_to_device(
device_data(), host_data(), size());
device_.get(), host_.data(), host_.capacity());
}
}
@ -412,8 +409,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_to_host(
host_data(), ptr_device, count);
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
}
/// Copy data from a caller-supplied device pointer into host memory.
@ -427,8 +425,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_device_to_device(
device_data(), ptr_device, count);
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
}
/// Copy data from a caller-supplied device pointer into host memory.
@ -442,8 +441,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_to_device(
device_data(), ptr_host, count);
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
}
/// Copy data from a caller-supplied device pointer into host memory.
@ -457,8 +457,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_host_to_host(
host_data(), ptr_host, count);
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
}
/// Copy data from a caller-supplied device pointer into host memory.
@ -472,8 +473,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_to_host(
ptr_host, device_data(), count);
reinterpret_cast<StorageUnit *>(ptr_host), device_.get(), container_count);
}
/// Copy data from a caller-supplied device pointer into host memory.
@ -487,8 +489,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_device_to_device(
ptr_device, device_data(), count);
reinterpret_cast<StorageUnit *>(ptr_device), device_.get(), container_count);
}
/// Copy data from a caller-supplied device pointer into host memory.
@ -502,8 +505,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_to_device(
ptr_device, host_data(), count);
reinterpret_cast<StorageUnit *>(ptr_device), host_.data(), container_count);
}
/// Copy data from a caller-supplied device pointer into host memory.
@ -517,8 +521,9 @@ public:
else {
count = __NV_STD_MIN(capacity(), count);
}
size_t container_count = count_to_container_storage_unit_count(count);
device_memory::copy_host_to_host(
ptr_host, host_data(), count);
reinterpret_cast<StorageUnit *>(ptr_host), host_.data(), container_count);
}
};

View File

@ -458,6 +458,7 @@ make_cute_packed_stride(
// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>> s,
@ -497,6 +498,71 @@ make_cute_packed_stride(
return s_copy;
}
//
// Wgrad output tensor ((_1, s, r, t), k, _0)
//
// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0)
/// Populates a wgrad filter-output stride for cutlass::layout::TensorCSK:
/// rank-3 stride ((_1, s), k, _0). The dynamic s and k strides are copied
/// from the packed KSC stride array; the shape argument is unused.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>> s,
    [[maybe_unused]] cute::array<int32_t, 3> shape_output,
    cute::array<IntT, 3> stride_ksc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  // The channel (innermost) mode must be contiguous.
  assert(stride_ksc[2] == 1);
  auto result = s;
  // s-mode stride within the filter mode.
  cute::get<0,1>(result) = stride_ksc[1];
  // k-mode stride.
  cute::get<1,0>(result) = stride_ksc[0];
  return result;
}
// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0)
/// Populates a wgrad filter-output stride for cutlass::layout::TensorCSRK:
/// rank-3 stride ((_1, s, r), k, _0). The dynamic s, r, and k strides are
/// copied from the packed KRSC stride array; the shape argument is unused.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>> s,
    [[maybe_unused]] cute::array<int32_t, 4> shape_output,
    cute::array<IntT, 4> stride_krsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  // The channel (innermost) mode must be contiguous.
  assert(stride_krsc[3] == 1);
  auto result = s;
  // k-mode stride.
  cute::get<1,0>(result) = stride_krsc[0];
  // Unrolled assignment of the spatial strides (original used cute::for_each):
  cute::get<0,2>(result) = stride_krsc[1];  // r-mode stride
  cute::get<0,1>(result) = stride_krsc[2];  // s-mode stride
  return result;
}
// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0)
/// Populates a wgrad filter-output stride for cutlass::layout::TensorCSRTK:
/// rank-3 stride ((_1, s, r, t), k, _0). The dynamic s, r, t, and k strides
/// are copied from the packed KTRSC stride array; the shape argument is unused.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>> s,
    [[maybe_unused]] cute::array<int32_t, 5> shape_output,
    cute::array<IntT, 5> stride_ktrsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  // The channel (innermost) mode must be contiguous.
  assert(stride_ktrsc[4] == 1);
  auto result = s;
  // k-mode stride.
  cute::get<1,0>(result) = stride_ktrsc[0];
  // Unrolled assignment of the spatial strides (original used cute::for_each):
  cute::get<0,3>(result) = stride_ktrsc[1];  // t-mode stride
  cute::get<0,2>(result) = stride_ktrsc[2];  // r-mode stride
  cute::get<0,1>(result) = stride_ktrsc[3];  // s-mode stride
  return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -443,10 +443,10 @@ struct RandomUniformFunc {
int int_scale_ = -1
):
seed(seed_),
range(static_cast<FloatType>(max_ - min)),
range(static_cast<FloatType>(max_) - static_cast<FloatType>(min)),
max(static_cast<FloatType>(max_)),
int_scale(int_scale_) {
float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits
float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale);
}

View File

@ -125,7 +125,8 @@ template<
class ShapePadding,
class StrideTraversal,
class ShapeDilation,
class EpilogueFusionParams>
class EpilogueFusionParams
>
struct ConvReferenceImpl {
using ElementAcc = typename EpilogueFusionParams::ElementAcc;
using ElementC = typename EpilogueFusionParams::ElementC;
@ -145,7 +146,6 @@ struct ConvReferenceImpl {
NumericConverter<ElementOut, ElementCompute> output_converter;
EpilogueFusionParams& epi_fusion_params_;
TensorA const& tensor_a_;
TensorB const& tensor_b_;
TensorC const& tensor_c_;
@ -174,7 +174,8 @@ struct ConvReferenceImpl {
padding_(padding),
tstride_(tstride),
dilation_(dilation),
epi_fusion_params_(epi_fusion_params) {
epi_fusion_params_(epi_fusion_params)
{
static_assert(rank(ShapePadding{}) == rank(ShapeDilation{}));
static_assert(rank(ShapePadding{}) == rank(StrideTraversal{}));
}
@ -211,7 +212,9 @@ private:
for (int32_t c = 0; c < C; ++c) {
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
if (detail::is_activation_in_bounds(tensor_a_, n, w, c)) {
accumulator += ElementAcc(tensor_a_(c, w, n) * tensor_b_(c, s, k));
auto a = tensor_a_(c, w, n);
auto b = tensor_b_(c, s, k);
accumulator += ElementAcc(a * b);
}
}
}
@ -256,7 +259,9 @@ private:
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c)) {
accumulator += ElementAcc(tensor_a_(c, w, h, n) * tensor_b_(c, s, r, k));
auto a = tensor_a_(c, w, h, n);
auto b = tensor_b_(c, s, r, k);
accumulator += ElementAcc(a * b);
}
}
}
@ -308,7 +313,9 @@ private:
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c)) {
accumulator += ElementAcc(tensor_a_(c, w, h, d, n) * tensor_b_(c, s, r, t, k));
auto a = tensor_a_(c, w, h, d, n);
auto b = tensor_b_(c, s, r, t, k);
accumulator += ElementAcc(a * b);
}
}
}
@ -516,9 +523,12 @@ private:
// Specialization for 1D wgrad kernel
void wgrad_reference(cute::Int<1> spatial_dims) {
int32_t N = size<2>(tensor_a_);
int32_t Q = size<1>(tensor_a_);
int32_t K = size<0>(tensor_a_);
int32_t N =
size<2>(tensor_a_);
int32_t Q =
size<1>(tensor_a_);
int32_t K =
size<0>(tensor_a_);
int32_t S = size<1>(tensor_d_);
int32_t C = size<0>(tensor_d_);
@ -536,8 +546,14 @@ private:
for (int32_t n = 0; n < N; ++n) {
for (int32_t q = 0; q < Q; ++q) {
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
if (detail::is_activation_in_bounds(tensor_b_, n, w, c)) {
accumulator += ElementAcc(tensor_b_(c, w, n) * tensor_a_(k, q, n));
bool is_in_bounds =
detail::is_activation_in_bounds(tensor_b_, n, w, c);
if (is_in_bounds) {
auto act =
tensor_b_(c, w, n);
auto xformed_act =
tensor_a_(k, q, n);
accumulator += ElementAcc(act * xformed_act);
}
}
}
@ -555,10 +571,14 @@ private:
// Specialization for 2D wgrad kernel
void wgrad_reference(cute::Int<2> spatial_dims) {
int32_t N = size<3>(tensor_a_);
int32_t P = size<2>(tensor_a_);
int32_t Q = size<1>(tensor_a_);
int32_t K = size<0>(tensor_a_);
int32_t N =
size<3>(tensor_a_);
int32_t P =
size<2>(tensor_a_);
int32_t Q =
size<1>(tensor_a_);
int32_t K =
size<0>(tensor_a_);
int32_t R = size<2>(tensor_d_);
int32_t S = size<1>(tensor_d_);
int32_t C = size<0>(tensor_d_);
@ -580,8 +600,14 @@ private:
for (int32_t q = 0; q < Q; ++q) {
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
if (detail::is_activation_in_bounds(tensor_b_, n, h, w, c)) {
accumulator += ElementAcc(tensor_b_(c, w, h, n) * tensor_a_(k, q, p, n));
bool is_in_bounds =
detail::is_activation_in_bounds(tensor_b_, n, h, w, c);
if (is_in_bounds) {
auto act =
tensor_b_(c, w, h, n);
auto xformed_act =
tensor_a_(k, q, p, n);
accumulator += ElementAcc(act * xformed_act);
}
}
}
@ -601,11 +627,16 @@ private:
// Specialization for 3D wgrad kernel
void wgrad_reference(cute::Int<3> spatial_dims) {
int32_t N = size<4>(tensor_a_);
int32_t Z = size<3>(tensor_a_);
int32_t P = size<2>(tensor_a_);
int32_t Q = size<1>(tensor_a_);
int32_t K = size<0>(tensor_a_);
int32_t N =
size<4>(tensor_a_);
int32_t Z =
size<3>(tensor_a_);
int32_t P =
size<2>(tensor_a_);
int32_t Q =
size<1>(tensor_a_);
int32_t K =
size<0>(tensor_a_);
int32_t T = size<3>(tensor_d_);
int32_t R = size<2>(tensor_d_);
int32_t S = size<1>(tensor_d_);
@ -631,8 +662,14 @@ private:
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
if (detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c)) {
accumulator += ElementAcc(tensor_b_(c, w, h, d, n) * tensor_a_(k, q, p, z, n));
bool is_in_bounds =
detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c);
if (is_in_bounds) {
auto act =
tensor_b_(c, w, h, d, n);
auto xformed_act =
tensor_a_(k, q, p, z, n);
accumulator += ElementAcc(act * xformed_act);
}
}
}

View File

@ -82,7 +82,6 @@ struct GettMainloopParams {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementScalingFactor_,
@ -117,7 +116,6 @@ struct GettEpilogueParams {
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
static constexpr bool PerColumnBias = PerColumnBias_;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
@ -184,6 +182,8 @@ void gett_mainloop(
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
@ -254,6 +254,8 @@ void gett_epilogue(
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::TensorC::value_type;
using ElementD = typename EpilogueParams::TensorD::value_type;
@ -265,7 +267,6 @@ void gett_epilogue(
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
@ -300,7 +301,7 @@ void gett_epilogue(
// Output related converter
NumericConverter<ElementD, ElementCompute> destination_converter;
NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
NumericConverter<ElementBias, ElementCompute> dBias_converter;
// Epilogue operations
@ -417,6 +418,7 @@ void gett_epilogue(
}
}
}
#if defined(_OPENMP)
#pragma omp critical(Abs_Max_Data_Update)
#endif