@ -104,23 +104,15 @@ public:
|
||||
/// Constant reference to element in tensor
|
||||
using ConstReference = typename ConstTensorRef::Reference;
|
||||
|
||||
/// Note: Below is used to handle packing of subbyte elements
|
||||
/// kBitsStoredVec : The bits of store vec that could be divisiable by the element
|
||||
/// kElementsPerStoredVec : The number of elements could be stored in per store vec
|
||||
/// kNumStoragePerStoredVec : How much storage(i.e. sizeof(element storage)) the store vec needs to consume.
|
||||
/// Usually the element storage of subbyte is uint8_t.
|
||||
/// Example
|
||||
/// int2: kBitsStoredVec = 8; kElementsPerStoredVec = 4; kNumStoragePerStoredVec = 1 uint8_t;
|
||||
/// int4: kBitsStoredVec = 8; kElementsPerStoredVec = 2; kNumStoragePerStoredVec = 1 uint8_t;
|
||||
static constexpr int kBitsStoredVec = (sizeof_bits<Element>::value < 8) ? cutlass::lcm(sizeof_bits<Element>::value, 8) : sizeof_bits<Element>::value;
|
||||
static constexpr int kElementsPerStoredVec = kBitsStoredVec / sizeof_bits<Element>::value;
|
||||
static constexpr int kNumStoragePerStoredVec = kBitsStoredVec / (sizeof(Element) * 8);
|
||||
|
||||
static_assert(kBitsStoredVec != 0, "kBitsStoredVec can not be zero");
|
||||
static_assert(kElementsPerStoredVec != 0, "kElementsPerStoredVec can not be zero");
|
||||
static_assert(kNumStoragePerStoredVec != 0, "kNumStoragePerStoredVec can not be zero");
|
||||
|
||||
private:
|
||||
private:
|
||||
using StorageUnit = typename platform::conditional_t<std::is_same_v<Element, bool>, uint8_t, // Avoid the std::vector<bool> specialization
|
||||
typename platform::conditional_t<sizeof_bits<Element>::value % 8 == 0, // Handle subbyte types
|
||||
Element, uint8_t>>;
|
||||
using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
|
||||
static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits;
|
||||
static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements;
|
||||
static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes;
|
||||
static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit;
|
||||
|
||||
//
|
||||
// Data members
|
||||
@ -133,13 +125,17 @@ public:
|
||||
Layout layout_;
|
||||
|
||||
/// Host-side memory allocation
|
||||
/// avoid the std::vector<bool> specialization
|
||||
std::vector<std::conditional_t<std::is_same_v<Element,bool>, uint8_t, Element>> host_;
|
||||
std::vector<StorageUnit> host_;
|
||||
|
||||
/// Device-side memory
|
||||
device_memory::allocation<Element> device_;
|
||||
device_memory::allocation<StorageUnit> device_;
|
||||
|
||||
public:
|
||||
/// number of containers
|
||||
size_t count_to_container_storage_unit_count(size_t count) {
|
||||
return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit;
|
||||
}
|
||||
|
||||
public:
|
||||
//
|
||||
// Device and Host Methods
|
||||
//
|
||||
@ -185,15 +181,15 @@ public:
|
||||
device_.reset();
|
||||
host_.clear();
|
||||
|
||||
count = (count + kElementsPerStoredVec - 1) / kElementsPerStoredVec * kNumStoragePerStoredVec;
|
||||
host_.resize(count);
|
||||
size_t count_container = count_to_container_storage_unit_count(count);
|
||||
host_.resize(count_container);
|
||||
|
||||
// Allocate memory
|
||||
Element* device_memory = nullptr;
|
||||
StorageUnit* device_memory = nullptr;
|
||||
if (device_backed_) {
|
||||
device_memory = device_memory::allocate<Element>(count);
|
||||
device_memory = device_memory::allocate<StorageUnit>(count_container);
|
||||
}
|
||||
device_.reset(device_memory, device_backed_ ? count : 0);
|
||||
device_.reset(device_memory, device_backed_ ? count_container : 0);
|
||||
}
|
||||
|
||||
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
||||
@ -229,8 +225,9 @@ public:
|
||||
layout_ = layout;
|
||||
|
||||
LongIndex new_size = size_t(layout_.capacity(extent_));
|
||||
LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_)));
|
||||
|
||||
if (static_cast<decltype(host_.size())>(new_size) > host_.size()) {
|
||||
if (static_cast<decltype(host_.size())>(new_size_container) > host_.size()) {
|
||||
reserve(new_size, device_backed_);
|
||||
}
|
||||
}
|
||||
@ -244,14 +241,14 @@ public:
|
||||
resize(extent, Layout::packed(extent), device_backed_);
|
||||
}
|
||||
|
||||
/// Returns the number of elements stored in the host tensor
|
||||
/// Returns the logical number of elements stored in the host tensor
|
||||
size_t size() const {
|
||||
return host_.size() / kNumStoragePerStoredVec * kElementsPerStoredVec;
|
||||
return layout_.capacity(extent_);
|
||||
}
|
||||
|
||||
/// Returns the logical capacity based on extent and layout. May differ from size().
|
||||
/// Returns the logical capacity in terms of number of elements. May be larger than the size().
|
||||
LongIndex capacity() const {
|
||||
return layout_.capacity(extent_);
|
||||
return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements;
|
||||
}
|
||||
|
||||
/// Gets pointer to host data
|
||||
@ -277,10 +274,10 @@ public:
|
||||
}
|
||||
|
||||
/// Gets pointer to device data
|
||||
Element * device_data() { return device_.get(); }
|
||||
Element * device_data() { return reinterpret_cast<Element *>(device_.get()); }
|
||||
|
||||
/// Gets pointer to device data
|
||||
Element const * device_data() const { return device_.get(); }
|
||||
Element const * device_data() const { return reinterpret_cast<Element const *>(device_.get()); }
|
||||
|
||||
/// Gets pointer to device data with a pointer offset
|
||||
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
|
||||
@ -389,7 +386,7 @@ public:
|
||||
void sync_host() {
|
||||
if (device_backed()) {
|
||||
device_memory::copy_to_host(
|
||||
host_data(), device_data(), size());
|
||||
host_.data(), device_.get(), device_.size());
|
||||
}
|
||||
}
|
||||
|
||||
@ -397,7 +394,7 @@ public:
|
||||
void sync_device() {
|
||||
if (device_backed()) {
|
||||
device_memory::copy_to_device(
|
||||
device_data(), host_data(), size());
|
||||
device_.get(), host_.data(), host_.capacity());
|
||||
}
|
||||
}
|
||||
|
||||
@ -412,8 +409,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_to_host(
|
||||
host_data(), ptr_device, count);
|
||||
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
|
||||
}
|
||||
|
||||
/// Copy data from a caller-supplied device pointer into host memory.
|
||||
@ -427,8 +425,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_device_to_device(
|
||||
device_data(), ptr_device, count);
|
||||
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
|
||||
}
|
||||
|
||||
/// Copy data from a caller-supplied device pointer into host memory.
|
||||
@ -442,8 +441,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_to_device(
|
||||
device_data(), ptr_host, count);
|
||||
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
|
||||
}
|
||||
|
||||
/// Copy data from a caller-supplied device pointer into host memory.
|
||||
@ -457,8 +457,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_host_to_host(
|
||||
host_data(), ptr_host, count);
|
||||
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
|
||||
}
|
||||
|
||||
/// Copy data from a caller-supplied device pointer into host memory.
|
||||
@ -472,8 +473,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_to_host(
|
||||
ptr_host, device_data(), count);
|
||||
reinterpret_cast<StorageUnit *>(ptr_host), device_.get(), container_count);
|
||||
}
|
||||
|
||||
/// Copy data from a caller-supplied device pointer into host memory.
|
||||
@ -487,8 +489,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_device_to_device(
|
||||
ptr_device, device_data(), count);
|
||||
reinterpret_cast<StorageUnit *>(ptr_device), device_.get(), container_count);
|
||||
}
|
||||
|
||||
/// Copy data from a caller-supplied device pointer into host memory.
|
||||
@ -502,8 +505,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_to_device(
|
||||
ptr_device, host_data(), count);
|
||||
reinterpret_cast<StorageUnit *>(ptr_device), host_.data(), container_count);
|
||||
}
|
||||
|
||||
/// Copy data from a caller-supplied device pointer into host memory.
|
||||
@ -517,8 +521,9 @@ public:
|
||||
else {
|
||||
count = __NV_STD_MIN(capacity(), count);
|
||||
}
|
||||
size_t container_count = count_to_container_storage_unit_count(count);
|
||||
device_memory::copy_host_to_host(
|
||||
ptr_host, host_data(), count);
|
||||
reinterpret_cast<StorageUnit *>(ptr_host), host_.data(), container_count);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -458,6 +458,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>> s,
|
||||
@ -497,6 +498,71 @@ make_cute_packed_stride(
|
||||
return s_copy;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Wgrad output tensor ((_1, s, r, t), k, _0)
|
||||
//
|
||||
|
||||
// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>> s,
|
||||
[[maybe_unused]] cute::array<int32_t, 3> shape_output,
|
||||
cute::array<IntT, 3> stride_ksc,
|
||||
conv::Operator ConvOp) {
|
||||
static_assert(std::is_integral_v<IntT>,
|
||||
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
||||
|
||||
assert(stride_ksc[2] == 1);
|
||||
auto s_copy = s;
|
||||
cute::get<1,0>(s_copy) = stride_ksc[0];
|
||||
cute::get<0,1>(s_copy) = stride_ksc[1];
|
||||
return s_copy;
|
||||
}
|
||||
|
||||
// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>> s,
|
||||
[[maybe_unused]] cute::array<int32_t, 4> shape_output,
|
||||
cute::array<IntT, 4> stride_krsc,
|
||||
conv::Operator ConvOp) {
|
||||
static_assert(std::is_integral_v<IntT>,
|
||||
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
||||
|
||||
assert(stride_krsc[3] == 1);
|
||||
auto s_copy = s;
|
||||
cute::get<1,0>(s_copy) = stride_krsc[0];
|
||||
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
||||
cute::get<0,2-i>(s_copy) = stride_krsc[i+1];
|
||||
});
|
||||
return s_copy;
|
||||
}
|
||||
|
||||
// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>> s,
|
||||
[[maybe_unused]] cute::array<int32_t, 5> shape_output,
|
||||
cute::array<IntT, 5> stride_ktrsc,
|
||||
conv::Operator ConvOp) {
|
||||
static_assert(std::is_integral_v<IntT>,
|
||||
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
||||
|
||||
assert(stride_ktrsc[4] == 1);
|
||||
auto s_copy = s;
|
||||
cute::get<1,0>(s_copy) = stride_ktrsc[0];
|
||||
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
||||
cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1];
|
||||
});
|
||||
return s_copy;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -443,10 +443,10 @@ struct RandomUniformFunc {
|
||||
int int_scale_ = -1
|
||||
):
|
||||
seed(seed_),
|
||||
range(static_cast<FloatType>(max_ - min)),
|
||||
range(static_cast<FloatType>(max_) - static_cast<FloatType>(min)),
|
||||
max(static_cast<FloatType>(max_)),
|
||||
int_scale(int_scale_) {
|
||||
|
||||
|
||||
float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits
|
||||
float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale);
|
||||
}
|
||||
|
||||
@ -125,7 +125,8 @@ template<
|
||||
class ShapePadding,
|
||||
class StrideTraversal,
|
||||
class ShapeDilation,
|
||||
class EpilogueFusionParams>
|
||||
class EpilogueFusionParams
|
||||
>
|
||||
struct ConvReferenceImpl {
|
||||
using ElementAcc = typename EpilogueFusionParams::ElementAcc;
|
||||
using ElementC = typename EpilogueFusionParams::ElementC;
|
||||
@ -145,7 +146,6 @@ struct ConvReferenceImpl {
|
||||
NumericConverter<ElementOut, ElementCompute> output_converter;
|
||||
|
||||
EpilogueFusionParams& epi_fusion_params_;
|
||||
|
||||
TensorA const& tensor_a_;
|
||||
TensorB const& tensor_b_;
|
||||
TensorC const& tensor_c_;
|
||||
@ -174,7 +174,8 @@ struct ConvReferenceImpl {
|
||||
padding_(padding),
|
||||
tstride_(tstride),
|
||||
dilation_(dilation),
|
||||
epi_fusion_params_(epi_fusion_params) {
|
||||
epi_fusion_params_(epi_fusion_params)
|
||||
{
|
||||
static_assert(rank(ShapePadding{}) == rank(ShapeDilation{}));
|
||||
static_assert(rank(ShapePadding{}) == rank(StrideTraversal{}));
|
||||
}
|
||||
@ -211,7 +212,9 @@ private:
|
||||
for (int32_t c = 0; c < C; ++c) {
|
||||
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
||||
if (detail::is_activation_in_bounds(tensor_a_, n, w, c)) {
|
||||
accumulator += ElementAcc(tensor_a_(c, w, n) * tensor_b_(c, s, k));
|
||||
auto a = tensor_a_(c, w, n);
|
||||
auto b = tensor_b_(c, s, k);
|
||||
accumulator += ElementAcc(a * b);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -256,7 +259,9 @@ private:
|
||||
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
||||
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
||||
if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c)) {
|
||||
accumulator += ElementAcc(tensor_a_(c, w, h, n) * tensor_b_(c, s, r, k));
|
||||
auto a = tensor_a_(c, w, h, n);
|
||||
auto b = tensor_b_(c, s, r, k);
|
||||
accumulator += ElementAcc(a * b);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -308,7 +313,9 @@ private:
|
||||
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
||||
int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
|
||||
if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c)) {
|
||||
accumulator += ElementAcc(tensor_a_(c, w, h, d, n) * tensor_b_(c, s, r, t, k));
|
||||
auto a = tensor_a_(c, w, h, d, n);
|
||||
auto b = tensor_b_(c, s, r, t, k);
|
||||
accumulator += ElementAcc(a * b);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -516,9 +523,12 @@ private:
|
||||
|
||||
// Specialization for 1D wgrad kernel
|
||||
void wgrad_reference(cute::Int<1> spatial_dims) {
|
||||
int32_t N = size<2>(tensor_a_);
|
||||
int32_t Q = size<1>(tensor_a_);
|
||||
int32_t K = size<0>(tensor_a_);
|
||||
int32_t N =
|
||||
size<2>(tensor_a_);
|
||||
int32_t Q =
|
||||
size<1>(tensor_a_);
|
||||
int32_t K =
|
||||
size<0>(tensor_a_);
|
||||
int32_t S = size<1>(tensor_d_);
|
||||
int32_t C = size<0>(tensor_d_);
|
||||
|
||||
@ -536,8 +546,14 @@ private:
|
||||
for (int32_t n = 0; n < N; ++n) {
|
||||
for (int32_t q = 0; q < Q; ++q) {
|
||||
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
||||
if (detail::is_activation_in_bounds(tensor_b_, n, w, c)) {
|
||||
accumulator += ElementAcc(tensor_b_(c, w, n) * tensor_a_(k, q, n));
|
||||
bool is_in_bounds =
|
||||
detail::is_activation_in_bounds(tensor_b_, n, w, c);
|
||||
if (is_in_bounds) {
|
||||
auto act =
|
||||
tensor_b_(c, w, n);
|
||||
auto xformed_act =
|
||||
tensor_a_(k, q, n);
|
||||
accumulator += ElementAcc(act * xformed_act);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -555,10 +571,14 @@ private:
|
||||
|
||||
// Specialization for 2D wgrad kernel
|
||||
void wgrad_reference(cute::Int<2> spatial_dims) {
|
||||
int32_t N = size<3>(tensor_a_);
|
||||
int32_t P = size<2>(tensor_a_);
|
||||
int32_t Q = size<1>(tensor_a_);
|
||||
int32_t K = size<0>(tensor_a_);
|
||||
int32_t N =
|
||||
size<3>(tensor_a_);
|
||||
int32_t P =
|
||||
size<2>(tensor_a_);
|
||||
int32_t Q =
|
||||
size<1>(tensor_a_);
|
||||
int32_t K =
|
||||
size<0>(tensor_a_);
|
||||
int32_t R = size<2>(tensor_d_);
|
||||
int32_t S = size<1>(tensor_d_);
|
||||
int32_t C = size<0>(tensor_d_);
|
||||
@ -580,8 +600,14 @@ private:
|
||||
for (int32_t q = 0; q < Q; ++q) {
|
||||
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
||||
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
||||
if (detail::is_activation_in_bounds(tensor_b_, n, h, w, c)) {
|
||||
accumulator += ElementAcc(tensor_b_(c, w, h, n) * tensor_a_(k, q, p, n));
|
||||
bool is_in_bounds =
|
||||
detail::is_activation_in_bounds(tensor_b_, n, h, w, c);
|
||||
if (is_in_bounds) {
|
||||
auto act =
|
||||
tensor_b_(c, w, h, n);
|
||||
auto xformed_act =
|
||||
tensor_a_(k, q, p, n);
|
||||
accumulator += ElementAcc(act * xformed_act);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -601,11 +627,16 @@ private:
|
||||
|
||||
// Specialization for 3D wgrad kernel
|
||||
void wgrad_reference(cute::Int<3> spatial_dims) {
|
||||
int32_t N = size<4>(tensor_a_);
|
||||
int32_t Z = size<3>(tensor_a_);
|
||||
int32_t P = size<2>(tensor_a_);
|
||||
int32_t Q = size<1>(tensor_a_);
|
||||
int32_t K = size<0>(tensor_a_);
|
||||
int32_t N =
|
||||
size<4>(tensor_a_);
|
||||
int32_t Z =
|
||||
size<3>(tensor_a_);
|
||||
int32_t P =
|
||||
size<2>(tensor_a_);
|
||||
int32_t Q =
|
||||
size<1>(tensor_a_);
|
||||
int32_t K =
|
||||
size<0>(tensor_a_);
|
||||
int32_t T = size<3>(tensor_d_);
|
||||
int32_t R = size<2>(tensor_d_);
|
||||
int32_t S = size<1>(tensor_d_);
|
||||
@ -631,8 +662,14 @@ private:
|
||||
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
||||
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
||||
int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
|
||||
if (detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c)) {
|
||||
accumulator += ElementAcc(tensor_b_(c, w, h, d, n) * tensor_a_(k, q, p, z, n));
|
||||
bool is_in_bounds =
|
||||
detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c);
|
||||
if (is_in_bounds) {
|
||||
auto act =
|
||||
tensor_b_(c, w, h, d, n);
|
||||
auto xformed_act =
|
||||
tensor_a_(k, q, p, z, n);
|
||||
accumulator += ElementAcc(act * xformed_act);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,7 +82,6 @@ struct GettMainloopParams {
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementScalingFactor_,
|
||||
@ -117,7 +116,6 @@ struct GettEpilogueParams {
|
||||
using EngineD = typename TensorD::engine_type;
|
||||
using LayoutD = typename TensorD::layout_type;
|
||||
static constexpr bool PerColumnBias = PerColumnBias_;
|
||||
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
|
||||
@ -184,6 +182,8 @@ void gett_mainloop(
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
||||
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
||||
@ -254,6 +254,8 @@ void gett_epilogue(
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementCompute = typename EpilogueParams::ElementCompute;
|
||||
using ElementC = typename EpilogueParams::TensorC::value_type;
|
||||
using ElementD = typename EpilogueParams::TensorD::value_type;
|
||||
@ -265,7 +267,6 @@ void gett_epilogue(
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
|
||||
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
@ -300,7 +301,7 @@ void gett_epilogue(
|
||||
|
||||
// Output related converter
|
||||
NumericConverter<ElementD, ElementCompute> destination_converter;
|
||||
NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
|
||||
NumericConverter<ElementBias, ElementCompute> dBias_converter;
|
||||
|
||||
// Epilogue operations
|
||||
@ -417,6 +418,7 @@ void gett_epilogue(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp critical(Abs_Max_Data_Update)
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user