Updates for CUTLASS 3.5.0 (#1468)
@@ -109,7 +109,7 @@ struct GemmFunctionalKey {

   inline
   bool operator==(GemmFunctionalKey const &rhs) const {
-    return
+    return
       (provider == rhs.provider) &&
       (gemm_kind == rhs.gemm_kind) &&
       (element_compute == rhs.element_compute) &&
@@ -165,7 +165,7 @@ struct GemmFunctionalKeyHasher {

  inline
  static size_t rotl(size_t key, int shl) {
-    return (key << shl) | (key >> (sizeof(key)*8 - shl));
+    return (key << shl) | (key >> (sizeof(key)*8u - static_cast<size_t>(shl)));
  }

  inline
@@ -173,8 +173,8 @@ struct GemmFunctionalKeyHasher {
    IntHash hash;

    return
-      rotl(hash(int(key.provider)), 1) ^
-      rotl(hash(int(key.gemm_kind)), 2) ^
+      rotl(hash(int(key.provider)), 1) ^
+      rotl(hash(int(key.gemm_kind)), 2) ^
      rotl(hash(int(key.element_compute)), 3) ^
      rotl(hash(int(key.element_scalar)), 4) ^
      rotl(hash(int(key.element_A)), 5) ^
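The rotate-and-XOR combiner used by these hashers can be exercised in isolation. The sketch below is not part of the commit; it mirrors the pattern above and uses the same explicit cast, which keeps the shift count computed in unsigned size_t arithmetic instead of mixing it with a signed int.

```cpp
#include <cstddef>
#include <functional>

// Rotate-left on size_t; casting shl keeps (width - shl) in unsigned arithmetic,
// mirroring the change from sizeof(key)*8 - shl to the static_cast form above.
static std::size_t rotl(std::size_t key, int shl) {
  return (key << shl) | (key >> (sizeof(key) * 8u - static_cast<std::size_t>(shl)));
}

// Toy version of the functional-key hashers: hash each enum-like field as an int,
// rotate each hash by a distinct amount, and XOR the results together.
struct ToyKeyHasher {
  std::size_t operator()(int provider, int gemm_kind, int element_compute) const {
    std::hash<int> hash;
    return rotl(hash(provider), 1) ^
           rotl(hash(gemm_kind), 2) ^
           rotl(hash(element_compute), 3);
  }
};
```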
@@ -207,7 +207,7 @@ struct GemmPreferenceKey {
  GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { }

  bool operator<(GemmPreferenceKey const &rhs) const {
-    return (compute_capability < rhs.compute_capability) ||
+    return (compute_capability < rhs.compute_capability) ||
      ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment));
  }

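For readers tracing the ordering, this operator< is a plain lexicographic compare on (compute_capability, alignment); a hypothetical equivalent written with std::tie:

```cpp
#include <tuple>

struct ToyPreferenceKey {
  int compute_capability;
  int alignment;

  // Equivalent lexicographic ordering: compare compute_capability first,
  // then alignment only when the capabilities match.
  bool operator<(ToyPreferenceKey const &rhs) const {
    return std::tie(compute_capability, alignment) <
           std::tie(rhs.compute_capability, rhs.alignment);
  }
};
```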
@@ -288,9 +288,9 @@ struct ConvFunctionalKey {
    layout_C(layout_C),
    element_accumulator(element_accumulator),
    element_compute(element_compute)
-  { }
+  { }

-  inline
+  inline
  bool operator==(ConvFunctionalKey const &rhs) const {
    return
      (provider == rhs.provider) &&
@@ -305,7 +305,7 @@ struct ConvFunctionalKey {
      (element_compute == rhs.element_compute);
  }

-  inline
+  inline
  bool operator!=(ConvFunctionalKey const &rhs) const {
    return !(*this == rhs);
  }
@@ -325,7 +325,7 @@ std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctio
    << "element_accumulator: " << to_string(key.element_accumulator) << std::endl
    << "element_compute: " << to_string(key.element_compute) << std::endl
    << "}";
-
+
  return out;
}

@@ -335,14 +335,14 @@ struct ConvFunctionalKeyHasher {

  inline
  static size_t rotl(size_t key, int shl) {
-    return (key << shl) | (key >> (sizeof(key)*8 - shl));
+    return (key << shl) | (key >> (sizeof(key)*8u - static_cast<size_t>(shl)));
  }

  inline
  size_t operator()(ConvFunctionalKey const &key) const {
    IntHash hash;

-    return
+    return
      rotl(hash(int(key.provider)), 1) ^
      rotl(hash(int(key.conv_kind)), 2) ^
      rotl(hash(int(key.element_A)), 3) ^
@@ -370,11 +370,11 @@ struct ConvPreferenceKey {

  ConvPreferenceKey(): compute_capability(), iterator_algorithm() { }

-  ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm):
+  ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm):
    compute_capability(cc), iterator_algorithm(iterator_algorithm) { }

  bool operator<(ConvPreferenceKey const &rhs) const {
-    return (compute_capability < rhs.compute_capability) ||
+    return (compute_capability < rhs.compute_capability) ||
      ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm));
  }

@@ -433,9 +433,9 @@ struct ReductionFunctionalKey {
    element_compute(element_compute),
    reduce_math_op(reduce_math_op),
    epilogue_math_op(epilogue_math_op)
-  { }
+  { }

-  inline
+  inline
  bool operator==(ReductionFunctionalKey const &rhs) const {
    return
      (provider == rhs.provider) &&
@@ -447,7 +447,7 @@ struct ReductionFunctionalKey {
      (epilogue_math_op == rhs.epilogue_math_op);
  }

-  inline
+  inline
  bool operator!=(ReductionFunctionalKey const &rhs) const {
    return !(*this == rhs);
  }
@@ -459,14 +459,14 @@ struct ReductionFunctionalKeyHasher {

  inline
  static size_t rotl(size_t key, int shl) {
-    return (key << shl) | (key >> (sizeof(key)*8 - shl));
+    return (key << shl) | (key >> (sizeof(key)*8u - static_cast<size_t>(shl)));
  }

  inline
  size_t operator()(ReductionFunctionalKey const &key) const {
    IntHash hash;

-    return
+    return
      rotl(hash(int(key.provider)), 1) ^
      rotl(hash(int(key.element_workspace)), 2) ^
      rotl(hash(int(key.element_accumulator)), 3) ^
@@ -505,19 +505,19 @@ using ReductionOperationFunctionalMap = std::unordered_map<
class OperationTable {
public:

-  /// Map of all operations of type kGemm
+  /// Map of all operations of type kGemm
  // provider (kCUTLASS)
  GemmOperationFunctionalMap gemm_operations;

-  /// Map of all operations of type kConv2d
+  /// Map of all operations of type kConv2d
  // provider (kCUTLASS, kReferenceHost, kReferenceDevice)
  ConvOperationFunctionalMap conv2d_operations;

-  /// Map of all operations of type kConv3d
+  /// Map of all operations of type kConv3d
  // provider (kCUTLASS, kReferenceHost, kReferenceDevice)
  ConvOperationFunctionalMap conv3d_operations;

-  /// Map of all operations of type kConv2d
+  /// Map of all operations of type kConv2d
  // provider (kCUTLASS)
  ReductionOperationFunctionalMap reduction_operations;

@@ -38,6 +38,7 @@
 #include "cutlass/library/library.h"
 #include "library_internal.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
 #include <unordered_map>

 ///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -271,7 +272,6 @@ public:
  /// Returns success if the operation can proceed
  Status can_implement(
    void const *configuration_ptr, void const *arguments_ptr) const override {
-
    GemmUniversalConfiguration const *configuration =
      static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
    GemmUniversalArguments const *arguments =
@@ -289,7 +289,6 @@ public:
      configuration->problem_size.n(),
      configuration->problem_size.k(),
      configuration->batch_count);
-
    return Operator::can_implement(args);
  }

@@ -152,6 +152,7 @@ template <> struct NumericTypeMap<cutlass::tfloat32_t> {
  static NumericTypeID const kId = NumericTypeID::kTF32;
};

+
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T> struct MathOperationMap {
@@ -422,6 +422,8 @@ Status from_string<Status>(std::string const &str) {

///////////////////////////////////////////////////////////////////////////////////////////////////

+///////////////////////////////////////////////////////////////////////////////////////////////////
+
static struct {
  char const *text;
  char const *pretty;
@@ -238,7 +238,9 @@ protected:
    DeviceContext &device_context,
    library::Operation const *operation,
    ProblemSpace const &problem_space,
-    ProblemSpace::Problem const &problem);
+    ProblemSpace::Problem const &problem,
+    cutlass::library::NumericTypeID element_A,
+    cutlass::library::NumericTypeID element_B);

  /// Method to profile a CUTLASS Operation
  Status profile_cutlass_(
@@ -746,7 +746,13 @@ bool GemmOperationProfiler::verify_cutlass(
  }
#endif // #if CUTLASS_ENABLE_CUBLAS

-  bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem);
+  library::GemmDescription const &gemm_desc =
+    static_cast<library::GemmDescription const &>(operation->description());
+
+  cutlass::library::NumericTypeID element_A = gemm_desc.A.element;
+  cutlass::library::NumericTypeID element_B = gemm_desc.B.element;
+
+  bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem, element_A, element_B);

  // Update disposition to worst case verification outcome among all
  // verification providers which are supported
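The shape of this change, sketched with hypothetical stand-in types (NumericTypeID, GemmDescription, and verify_with_reference below are illustrative, not the real CUTLASS declarations): the caller reads the operand element types out of the operation description once and passes them into the reference check explicitly.

```cpp
#include <iostream>

// Illustrative stand-ins: a GEMM description records the numeric type of each operand.
enum class NumericTypeID { kF16, kF32 };
struct TensorDescription { NumericTypeID element; };
struct GemmDescription   { TensorDescription A, B; };

// Hypothetical verifier mirroring the extra element_A / element_B parameters
// threaded through verify_with_reference_ above.
bool verify_with_reference(NumericTypeID element_A, NumericTypeID element_B) {
  return element_A == NumericTypeID::kF16 && element_B == NumericTypeID::kF16;
}

int main() {
  GemmDescription gemm_desc{{NumericTypeID::kF16}, {NumericTypeID::kF16}};

  // The caller extracts the operand element types from the description once...
  NumericTypeID element_A = gemm_desc.A.element;
  NumericTypeID element_B = gemm_desc.B.element;

  // ...and passes them explicitly, so the reference check does not have to
  // re-derive them from the operation's description.
  std::cout << std::boolalpha << verify_with_reference(element_A, element_B) << '\n';
}
```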
@@ -912,8 +918,10 @@ bool GemmOperationProfiler::verify_with_reference_(
  DeviceContext &device_context,
  library::Operation const *operation,
  ProblemSpace const &problem_space,
-  ProblemSpace::Problem const &problem) {
-
+  ProblemSpace::Problem const &problem,
+  cutlass::library::NumericTypeID element_A,
+  cutlass::library::NumericTypeID element_B)
+{
  library::GemmDescription const &gemm_desc =
    static_cast<library::GemmDescription const &>(operation->description());

@@ -976,13 +984,13 @@ bool GemmOperationProfiler::verify_with_reference_(

        problem_.alpha.data(),

-        gemm_desc.A.element,
+        element_A,
        gemm_desc.A.layout,
        gemm_desc.transform_A,
        ptr_A,
        int(gemm_workspace_.configuration.lda),

-        gemm_desc.B.element,
+        element_B,
        gemm_desc.B.layout,
        gemm_desc.transform_B,
        ptr_B,
@@ -1010,7 +1018,6 @@ bool GemmOperationProfiler::verify_with_reference_(
        results_.back().verification_map[provider] = Disposition::kNotRun;
        continue;
      }
-
      results_.back().status = status;

      if (provider == library::Provider::kReferenceHost) {
@@ -62,7 +62,6 @@ template <typename EngineType, typename LayoutType>
matrix_inf_norm_result
matrix_inf_norm(cute::Tensor<EngineType, LayoutType> const& host_matrix)
{
-  using std::abs;
  using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
  using element_type = typename EngineType::value_type;

@@ -74,14 +73,25 @@ matrix_inf_norm(cute::Tensor<EngineType, LayoutType> const& host_matrix)
  const int64_t num_rows = cute::size<0>(host_matrix);
  const int64_t num_cols = cute::size<1>(host_matrix);

-  for(int64_t i = 0; i < num_rows; ++i) {
+  auto abs_fn = [] (element_type A_ij) {
+    if constexpr (not std::is_unsigned_v<element_type>) {
+      using std::abs;
+      return abs(A_ij);
+    }
+    else {
+      return A_ij;
+    }
+  };
+
+  for (int64_t i = 0; i < num_rows; ++i) {
    error_type row_abs_sum = 0.0;
    for(int64_t j = 0; j < num_cols; ++j) {
-      row_abs_sum += abs(host_matrix(i, j));
+      row_abs_sum += abs_fn(host_matrix(i, j));
    }
-    if(std::isnan(row_abs_sum)) {
+    if (std::isnan(row_abs_sum)) {
      found_nan = true;
-    } else {
+    }
+    else {
      inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
    }
  }
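The abs_fn lambda introduced above exists because calling abs on an unsigned element type is either meaningless or, for wide unsigned types, ambiguous. A standalone sketch of the same constexpr-if pattern:

```cpp
#include <cmath>
#include <cstdint>
#include <type_traits>

// Generic absolute value that also compiles for unsigned element types:
// the constexpr-if removes the abs() call entirely when T is unsigned,
// mirroring the abs_fn lambda in the diff above.
template <typename T>
T abs_or_identity(T x) {
  if constexpr (!std::is_unsigned_v<T>) {
    using std::abs;
    return abs(x);
  }
  else {
    return x;
  }
}

// A plain std::abs(std::uint64_t{3}) would be ambiguous; this compiles for both.
static_assert(std::is_same_v<decltype(abs_or_identity(std::int32_t{-3})), std::int32_t>);
static_assert(std::is_same_v<decltype(abs_or_identity(std::uint64_t{3})), std::uint64_t>);
```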
@@ -95,10 +105,19 @@ matrix_inf_norm_result
matrix_diff_inf_norm(cute::Tensor<EngineType, LayoutType> const& X,
                     cute::Tensor<EngineType, LayoutType> const& Y)
{
-  using std::abs;
  using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
  using element_type = typename EngineType::value_type;

+  auto abs_fn = [] (element_type A_ij) {
+    if constexpr (not std::is_unsigned_v<element_type>) {
+      using std::abs;
+      return abs(A_ij);
+    }
+    else {
+      return A_ij;
+    }
+  };
+
  assert(cute::size<0>(X) == cute::size<0>(Y));
  assert(cute::size<1>(X) == cute::size<1>(Y));

@@ -110,15 +129,16 @@ matrix_diff_inf_norm(cute::Tensor<EngineType, LayoutType> const& X,
  error_type inf_norm = 0.0;
  bool found_nan = false;

-  for(int64_t i = 0; i < num_rows; ++i) {
+  for (int64_t i = 0; i < num_rows; ++i) {
    error_type row_abs_sum = 0.0;
-    for(int64_t j = 0; j < num_cols; ++j) {
-      row_abs_sum += error_type(abs(element_type(X(i,j)) -
-                                    element_type(Y(i,j))));
+    for (int64_t j = 0; j < num_cols; ++j) {
+      row_abs_sum += error_type(abs_fn(element_type(X(i,j)) -
+                                       element_type(Y(i,j))));
    }
-    if(std::isnan(row_abs_sum)) {
+    if (std::isnan(row_abs_sum)) {
      found_nan = true;
-    } else {
+    }
+    else {
      inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
    }
  }
@@ -130,7 +150,7 @@ template <typename EngineType_A, typename LayoutType_A,
          typename EngineType_B, typename LayoutType_B,
          typename EngineType_C, typename LayoutType_C,
          typename EngineType_C_ref, typename LayoutType_C_ref>
-void
+auto
print_matrix_multiply_mollified_relative_error(
  char const A_value_type_name[],
  cute::Tensor<EngineType_A, LayoutType_A> const& A,
@@ -158,13 +178,13 @@ print_matrix_multiply_mollified_relative_error(
  using std::cout;
  using cute::shape;
  cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n'
-    << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
-    << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
-    << std::scientific
-    << "Infinity norm of A: " << A_norm << '\n'
-    << "Infinity norm of B: " << B_norm << '\n'
-    << "Infinity norm of C: " << C_norm << '\n'
-    << "Infinity norm of (C - C_ref): " << diff_norm << '\n';
+       << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
+       << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
+       << std::scientific
+       << "Infinity norm of A: " << A_norm << '\n'
+       << "Infinity norm of B: " << B_norm << '\n'
+       << "Infinity norm of C: " << C_norm << '\n'
+       << "Infinity norm of (C - C_ref): " << diff_norm << '\n';

  if(A_norm_times_B_norm == 0.0) {
    cout << "Mollified relative error: " << relative_error << '\n';
@@ -173,15 +193,16 @@ print_matrix_multiply_mollified_relative_error(
  }

  if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) {
-    cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
-      << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
-      << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
-      << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
+    cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
  }
+  return relative_error;
}

template <typename EngineType, typename LayoutType>
-void
+auto
print_matrix_multiply_mollified_relative_error(
  const char value_type_name[],
  const cute::Tensor<EngineType, LayoutType>& A,
@@ -189,7 +210,7 @@ print_matrix_multiply_mollified_relative_error(
  const cute::Tensor<EngineType, LayoutType>& C_computed,
  const cute::Tensor<EngineType, LayoutType>& C_expected)
{
-  print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
+  return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
                                                 value_type_name, C_computed, C_expected);
}

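Because these helpers now return the mollified relative error instead of void, a test can act on the value rather than just reading the log. A simplified, self-contained sketch of that usage (the helper below is illustrative, not the CUTLASS function):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Simplified stand-in for the helper above: print the error, then *return* the
// mollified relative error ||C - C_ref|| / (||A|| * ||B||) so the caller can act
// on it (the old void version could only print it).
double print_and_return_relative_error(std::vector<double> const &C,
                                       std::vector<double> const &C_ref,
                                       double A_norm_times_B_norm) {
  double diff_norm = 0.0;
  for (std::size_t i = 0; i < C.size(); ++i) {
    diff_norm = std::max(diff_norm, std::abs(C[i] - C_ref[i]));
  }
  double relative_error =
      (A_norm_times_B_norm == 0.0) ? diff_norm : diff_norm / A_norm_times_B_norm;
  std::printf("Mollified relative error: %e\n", relative_error);
  return relative_error;
}

int main() {
  std::vector<double> C{1.0, 2.0}, C_ref{1.0, 2.0 + 1e-9};
  // The returned value lets a test decide pass/fail instead of just logging.
  assert(print_and_return_relative_error(C, C_ref, 4.0) < 1e-5);
}
```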
@@ -314,7 +335,7 @@ print_relative_error(
  bool print_error = true,
  double error_margin = 0.00001) {
  assert(size(data) == size(reference));
-  return print_relative_error(static_cast<std::size_t>(size(data)),
-                              data, reference,
+  return print_relative_error(static_cast<std::size_t>(size(data)),
+                              data, reference,
                              print_verbose, print_error, error_margin);
}

@@ -1713,7 +1713,7 @@ void BlockFillSequential(
  Layout layout = Layout::packed(size);
  TensorView<Element, Layout> view(ptr, layout, size);

-  Array<Element, Layout::kRank> c;
+  Array<Element, Layout::kRank> c{};
  c[0] = v;

  TensorFillLinear(view, c, s);
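The `c{}` spelling value-initializes the array, so every rank beyond index 0 starts at zero instead of holding an indeterminate value. The same distinction with a plain std::array, as a sketch:

```cpp
#include <array>
#include <cstdio>

int main() {
  // Default-initialized: the elements of a local std::array<int, 3> are indeterminate.
  // std::array<int, 3> c_uninit;          // reading c_uninit[1] here would be UB

  // Value-initialized: all elements are zero before the first one is overwritten,
  // mirroring Array<Element, Layout::kRank> c{} in the diff above.
  std::array<int, 3> c{};
  c[0] = 42;

  std::printf("%d %d %d\n", c[0], c[1], c[2]);  // prints "42 0 0"
}
```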
@@ -41,6 +41,8 @@

#include "cute/tensor.hpp"

+#include <cuda_runtime.h>
+
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::reference::host {
@@ -93,7 +95,8 @@ template<
  class TensorAlpha_,
  class TensorBeta_,
  class TensorBias_,
-  class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>>
+  class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>
+>
struct ConvEpilogueFusionParams {
  using ElementAcc = ElementAcc_;
  using ElementScalar = ElementScalar_;
@@ -104,7 +107,6 @@ struct ConvEpilogueFusionParams {
  using TensorBeta = TensorBeta_;
  using TensorBias = TensorBias_;
  using ActivationFunctor = ActivationFunctor_;
-
  ElementScalar alpha = ElementScalar(1);
  ElementScalar beta = ElementScalar(0);

@@ -155,6 +157,7 @@ struct ConvReferenceImpl {

  // Epilogue activation operation
  ActivationFunctor epi_activation;
+
  ConvReferenceImpl(
    TensorA const& tensor_a,
    TensorB const& tensor_b,
@@ -201,7 +204,7 @@ private:
#pragma omp parallel for collapse(2)
#endif
    for (int32_t n = 0; n < N; ++n) {
-      for (int32_t q = 0; q < Q; ++q) {
+      for (int32_t q = 0; q < Q; ++q) {
        for (int32_t k = 0; k < K; ++k) {
          auto accumulator = ElementAcc(0);
          for (int32_t s = 0; s < S; ++s) {
@@ -226,6 +229,7 @@ private:
        }
      }
    }
+
  }

  // Specialization for 2D fprop kernel
@@ -272,6 +276,7 @@ private:
        }
      }
    }
+
  }

  // Specialization for 3D fprop kernel
@@ -325,6 +330,7 @@ private:
        }
      }
    }
+
  }

  // Specialization for 1D dgrad kernel
@@ -371,6 +377,7 @@ private:
        }
      }
    }
+
  }

  // Specialization for 2D dgrad kernel
@@ -424,11 +431,14 @@ private:
            if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
              output += bias_converter(epi_fusion_params_.tensor_bias[c]);
            }
+            output = epi_activation(output);
+
            tensor_d_(c, w, h, n) = output_converter(output);
          }
        }
      }
    }
+
  }

  // Specialization for 3D dgrad kernel
@@ -501,6 +511,7 @@ private:
        }
      }
    }
+
  }

  // Specialization for 1D wgrad kernel

@@ -197,7 +197,7 @@ void Depsep_Fprop(cutlass::TensorView<ElementA, LayoutA> tensor_A,
}

////////////////////////////////////////////////////////////////////////////////////////////////////
-/// Dgrad
+/// Dgrad / Deconv
////////////////////////////////////////////////////////////////////////////////////////////////////

/// dx = dgrad(dy, w)
@@ -221,7 +221,8 @@ void Conv2dDgrad(
  TensorRef<ElementC, LayoutC> tensor_dx_in,
  TensorRef<ElementD, LayoutC> tensor_dx_out,
  ElementCompute alpha,
-  ElementCompute beta) {
+  ElementCompute beta,
+  bool is_deconv = false) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
@@ -272,7 +273,8 @@ void Conv2dDgrad(
          if (p < problem_size.P && q < problem_size.Q) {

            ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k));
-            ElementB b = tensor_w.at(cutlass::make_Coord(k, r, s, c));
+            ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k))
+                                   : tensor_w.at(cutlass::make_Coord(k, r, s, c));

            acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
          }
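The is_deconv branch reads the filter with the input/output channel indices swapped: dgrad indexes the weights as (K, R, S, C), deconv as (C, R, S, K). A small sketch of that index mapping, assuming a packed row-major layout purely for illustration (CUTLASS itself goes through TensorRef/layout objects):

```cpp
#include <cstdio>

// For dgrad, the filter is addressed as (K, R, S, C): output channel k outermost.
// For deconv, the same loop addresses it as (C, R, S, K): the roles of the k and c
// indices are swapped, which is exactly what the is_deconv ternary above expresses.
int linear_weight_index(bool is_deconv, int k, int r, int s, int c,
                        int K, int R, int S, int C) {
  if (is_deconv) {
    // Treat the tensor as shape (C, R, S, K).
    return ((c * R + r) * S + s) * K + k;
  }
  // Treat the tensor as shape (K, R, S, C).
  return ((k * R + r) * S + s) * C + c;
}

int main() {
  // The same (k, r, s, c) coordinate lands at different offsets in the two modes.
  std::printf("dgrad:  %d\n", linear_weight_index(false, 1, 0, 0, 2, 4, 3, 3, 8));
  std::printf("deconv: %d\n", linear_weight_index(true,  1, 0, 0, 2, 4, 3, 3, 8));
}
```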
@@ -420,6 +422,7 @@ void Conv2d(
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
    break;

+  case conv::Operator::kDeconv:
  case conv::Operator::kDgrad:
    Conv2dDgrad<
      ElementA, LayoutA,
@@ -429,7 +432,7 @@ void Conv2d(
      ElementAccumulator,
      ElementD,
      ConvertOp, InnerProductOp
-    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
    break;

  case conv::Operator::kWgrad:
@@ -537,7 +540,7 @@ void Conv3dFprop(
}

////////////////////////////////////////////////////////////////////////////////////////////////////
-/// Dgrad
+/// Dgrad / Deconv
////////////////////////////////////////////////////////////////////////////////////////////////////

/// dx = dgrad(dy, w)
@@ -560,7 +563,8 @@ void Conv3dDgrad(
  TensorRef<ElementC, LayoutC> tensor_dx_in,
  TensorRef<ElementC, LayoutC> tensor_dx_out,
  ElementCompute alpha,
-  ElementCompute beta) {
+  ElementCompute beta,
+  bool is_deconv = false) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
@@ -604,8 +608,8 @@ void Conv3dDgrad(
            if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {

              ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k));
-              ElementB b = tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
-
+              ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k))
+                                     : tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
              acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
            }
          }
@@ -760,6 +764,7 @@ void Conv3d(
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
    break;

+  case conv::Operator::kDeconv:
  case conv::Operator::kDgrad:
    Conv3dDgrad<
      ElementA, LayoutA,
@@ -768,7 +773,7 @@ void Conv3d(
      ElementCompute,
      ElementAccumulator,
      ConvertOp, InnerProductOp
-    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
    break;

  case conv::Operator::kWgrad:

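Usage-wise, kDeconv now shares the dgrad reference path and differs only in the boolean forwarded to it. An illustrative dispatcher showing the fall-through shape (not the library's exact code):

```cpp
#include <cstdio>

enum class Operator { kFprop, kDgrad, kDeconv, kWgrad };

// Illustrative dispatcher: kDeconv falls through to the dgrad implementation
// and only flips the is_deconv flag, mirroring the switch in the reference
// Conv2d/Conv3d entry points above.
void run_reference(Operator op) {
  switch (op) {
  case Operator::kDeconv:   // fallthrough: same kernel, transposed filter indexing
  case Operator::kDgrad:
    std::printf("dgrad kernel, is_deconv = %s\n",
                op == Operator::kDeconv ? "true" : "false");
    break;
  default:
    std::printf("other kernel\n");
    break;
  }
}

int main() {
  run_reference(Operator::kDgrad);
  run_reference(Operator::kDeconv);
}
```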
@@ -35,10 +35,11 @@
#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/gemm/gemm.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/relatively_equal.h"

#include "cute/tensor.hpp"

@@ -115,7 +116,6 @@ struct GettEpilogueParams {
  using LayoutC = typename TensorC::layout_type;
  using EngineD = typename TensorD::engine_type;
  using LayoutD = typename TensorD::layout_type;

  static constexpr bool PerColumnBias = PerColumnBias_;
-
  ElementScalar alpha = ElementScalar(1);
