CUTLASS 3.5.0 (#1411)
@@ -97,6 +97,11 @@ template <typename OperatorClass> struct ArchMap<arch::Sm86, OperatorClass> {
  static int const kMax = 1024;
};

+template <typename OperatorClass> struct ArchMap<arch::Sm89, OperatorClass> {
+  static int const kMin = 89;
+  static int const kMax = 89;
+};
+
template <typename OperatorClass> struct ArchMap<arch::Sm90, OperatorClass> {
  static int const kMin = 90;
  static int const kMax = 1024;
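Note: each ArchMap specialization bounds the inclusive compute-capability range [kMin, kMax] for which a kernel is considered valid; the new Sm89 entry makes Sm89-specific kernels valid only on compute capability 8.9. A hedged sketch of how such a range check could look (is_arch_supported is hypothetical, not a library function):

// Hypothetical sketch, not library code: gate a kernel on the device's
// compute capability using ArchMap's inclusive [kMin, kMax] range.
template <typename ArchTag, typename OperatorClass>
bool is_arch_supported(int compute_capability) {
  using Map = ArchMap<ArchTag, OperatorClass>;
  return compute_capability >= Map::kMin && compute_capability <= Map::kMax;
}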
@@ -121,47 +121,47 @@ public:

struct GemmConfiguration {

  /// GEMM problem size
-  gemm::GemmCoord problem_size;
+  gemm::GemmCoord problem_size{};

  /// Leading dimension of A matrix
-  int64_t lda;
+  int64_t lda{0};

  /// Leading dimension of B matrix
-  int64_t ldb;
+  int64_t ldb{0};

  /// Leading dimension of C matrix
-  int64_t ldc;
+  int64_t ldc{0};

  /// Leading dimension of D matrix
-  int64_t ldd;
+  int64_t ldd{0};

  /// Number of partitions of K dimension
-  int split_k_slices;
+  int split_k_slices{0};
};

/// Arguments for GEMM
struct GemmArguments {

  /// Pointer to A matrix
-  void const *A;
+  void const *A{nullptr};

  /// Pointer to B matrix
-  void const *B;
+  void const *B{nullptr};

  /// Pointer to C matrix
-  void const *C;
+  void const *C{nullptr};

  /// Pointer to D matrix
-  void *D;
+  void *D{nullptr};

  /// Host or device pointer to alpha scalar
-  void const *alpha;
+  void const *alpha{nullptr};

  /// Host or device pointer to beta scalar
-  void const *beta;
+  void const *beta{nullptr};

  /// Enumerant indicating whether alpha/beta point to host or device memory
-  ScalarPointerMode pointer_mode;
+  ScalarPointerMode pointer_mode{};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
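Note: the pattern throughout this change is the same in every struct below: each member gains a default member initializer so that a default-constructed configuration or arguments object is in a well-defined state instead of holding indeterminate values. A minimal standalone sketch of the difference (example code, not library code):

#include <cstdint>

struct Before { std::int64_t lda; };     // Before{} value-initializes, but "Before b;" leaves lda indeterminate
struct After  { std::int64_t lda{0}; };  // "After a;" guarantees a.lda == 0

int main() {
  Before b;  // reading b.lda here is undefined behavior
  After a;   // a.lda is 0 by the default member initializer
  return static_cast<int>(a.lda);
}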
@@ -174,34 +174,34 @@ struct GemmArguments {

struct GemmBatchedConfiguration {

  /// GEMM problem size
-  gemm::GemmCoord problem_size;
+  gemm::GemmCoord problem_size{};

  /// Leading dimension of A matrix
-  int64_t lda;
+  int64_t lda{0};

  /// Leading dimension of B matrix
-  int64_t ldb;
+  int64_t ldb{0};

  /// Leading dimension of C matrix
-  int64_t ldc;
+  int64_t ldc{0};

  /// Leading dimension of D matrix
-  int64_t ldd;
+  int64_t ldd{0};

  /// Stride between instances of the A matrix in memory
-  int64_t batch_stride_A;
+  int64_t batch_stride_A{0};

  /// Stride between instances of the B matrix in memory
-  int64_t batch_stride_B;
+  int64_t batch_stride_B{0};

  /// Stride between instances of the C matrix in memory
-  int64_t batch_stride_C;
+  int64_t batch_stride_C{0};

  /// Stride between instances of the D matrix in memory
-  int64_t batch_stride_D;
+  int64_t batch_stride_D{0};

  /// Number of GEMMs in batch
-  int batch_count;
+  int batch_count{1};
};

/// Arguments to batched GEMM
@@ -216,32 +216,32 @@ using GemmBatchedArguments = GemmArguments;

struct GemmArrayConfiguration {

-  gemm::GemmCoord problem_size;
+  gemm::GemmCoord problem_size{};

  /// Leading dimension of A matrix
-  int64_t lda;
+  int64_t lda{0};

  /// Leading dimension of B matrix
-  int64_t ldb;
+  int64_t ldb{0};

  /// Leading dimension of C matrix
-  int64_t ldc;
+  int64_t ldc{0};

  /// Leading dimension of D matrix
-  int64_t ldd;
+  int64_t ldd{0};

-  int batch_count;
+  int batch_count{1};
};

/// Arguments for GEMM - used by all the GEMM operations
struct GemmArrayArguments {
-  void const * const *A;
-  void const * const *B;
-  void const * const *C;
-  void * const *D;
-  void const *alpha;
-  void const *beta;
-  ScalarPointerMode pointer_mode;
+  void const * const *A{nullptr};
+  void const * const *B{nullptr};
+  void const * const *C{nullptr};
+  void * const *D{nullptr};
+  void const *alpha{nullptr};
+  void const *beta{nullptr};
+  ScalarPointerMode pointer_mode{};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -253,45 +253,45 @@ struct GemmArrayArguments {

struct GemmUniversalConfiguration {

-  GemmUniversalMode mode;
-  gemm::GemmCoord problem_size;
-  int batch_count;
+  GemmUniversalMode mode{GemmUniversalMode::kGemm};
+  gemm::GemmCoord problem_size{};
+  int batch_count{1};

-  int64_t lda;
-  int64_t ldb;
-  int64_t ldc;
-  int64_t ldd;
+  int64_t lda{0};
+  int64_t ldb{0};
+  int64_t ldc{0};
+  int64_t ldd{0};
};

struct GemmUniversalArguments {
  // NOTE: these are replicated for 3.0 interfaces
-  gemm::GemmCoord problem_size;
-  int batch_count;
+  gemm::GemmCoord problem_size{};
+  int batch_count{1};

-  void const *A;
-  void const *B;
-  void const *C;
-  void *D;
+  void const *A{nullptr};
+  void const *B{nullptr};
+  void const *C{nullptr};
+  void *D{nullptr};

-  void const *alpha;
-  void const *beta;
-  ScalarPointerMode pointer_mode;
+  void const *alpha{nullptr};
+  void const *beta{nullptr};
+  ScalarPointerMode pointer_mode{};

  // NOTE: these are replicated for 3.0 interfaces
-  int64_t lda;
-  int64_t ldb;
-  int64_t ldc;
-  int64_t ldd;
+  int64_t lda{0};
+  int64_t ldb{0};
+  int64_t ldc{0};
+  int64_t ldd{0};

-  int64_t batch_stride_A;
-  int64_t batch_stride_B;
-  int64_t batch_stride_C;
-  int64_t batch_stride_D;
+  int64_t batch_stride_A{0};
+  int64_t batch_stride_B{0};
+  int64_t batch_stride_C{0};
+  int64_t batch_stride_D{0};

  // Needed for some 3.x kernels
-  int sm_count;
+  int sm_count{0};

-  library::RasterOrder raster_order;
+  library::RasterOrder raster_order{};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
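Note: with every field defaulted, a caller can populate only the fields a given problem needs and leave the rest in their well-defined defaults. A hedged usage sketch (ptr_A, ptr_B, ptr_C, ptr_D, alpha, and beta are assumed to exist in the caller; the sizes are invented for illustration):

// Hypothetical usage sketch, not library code.
cutlass::library::GemmUniversalArguments args{};
args.problem_size = {1024, 512, 256};  // M, N, K
args.batch_count  = 1;
args.A = ptr_A;  args.B = ptr_B;  args.C = ptr_C;  args.D = ptr_D;
args.alpha = &alpha;  args.beta = &beta;
args.pointer_mode = cutlass::library::ScalarPointerMode::kHost;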
@@ -303,53 +303,42 @@ struct GemmUniversalArguments {

struct GemmPlanarComplexConfiguration {

-  GemmUniversalMode mode;
-  gemm::GemmCoord problem_size;
-  int batch_count;
-
-  int64_t lda_real;
-  int64_t lda_imag;
-
-  int64_t ldb_real;
-  int64_t ldb_imag;
-
-  int64_t ldc_real;
-  int64_t ldc_imag;
-
-  int64_t ldd_real;
-  int64_t ldd_imag;
+  GemmUniversalMode mode{GemmUniversalMode::kGemm};
+  gemm::GemmCoord problem_size{};
+  int batch_count{1};
+  int64_t lda_real{0};
+  int64_t lda_imag{0};
+  int64_t ldb_real{0};
+  int64_t ldb_imag{0};
+  int64_t ldc_real{0};
+  int64_t ldc_imag{0};
+  int64_t ldd_real{0};
+  int64_t ldd_imag{0};
};

/// Arguments for planar complex GEMMs
struct GemmPlanarComplexArguments {

-  void const *A_real;
-  void const *A_imag;
-
-  void const *B_real;
-  void const *B_imag;
-
-  void const *C_real;
-  void const *C_imag;
-
-  void *D_real;
-  void *D_imag;
-
-  void const *alpha;
-  void const *beta;
-  ScalarPointerMode pointer_mode;
+  void const *A_real{nullptr};
+  void const *A_imag{nullptr};
+  void const *B_real{nullptr};
+  void const *B_imag{nullptr};
+  void const *C_real{nullptr};
+  void const *C_imag{nullptr};
+  void *D_real{nullptr};
+  void *D_imag{nullptr};
+  void const *alpha{nullptr};
+  void const *beta{nullptr};
+  ScalarPointerMode pointer_mode{};

-  int64_t batch_stride_A_real;
-  int64_t batch_stride_A_imag;
-
-  int64_t batch_stride_B_real;
-  int64_t batch_stride_B_imag;
-
-  int64_t batch_stride_C_real;
-  int64_t batch_stride_C_imag;
-
-  int64_t batch_stride_D_real;
-  int64_t batch_stride_D_imag;
+  int64_t batch_stride_A_real{0};
+  int64_t batch_stride_A_imag{0};
+  int64_t batch_stride_B_real{0};
+  int64_t batch_stride_B_imag{0};
+  int64_t batch_stride_C_real{0};
+  int64_t batch_stride_C_imag{0};
+  int64_t batch_stride_D_real{0};
+  int64_t batch_stride_D_imag{0};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -358,41 +347,38 @@ struct GemmPlanarComplexArguments {
/// from memory.
struct GemmPlanarComplexArrayConfiguration {

-  gemm::GemmCoord problem_size;
-  int batch_count;
+  gemm::GemmCoord problem_size{};
+  int batch_count{1};

-  int64_t lda_real;
-  int64_t lda_imag;
-
-  int64_t ldb_real;
-  int64_t ldb_imag;
-
-  int64_t ldc_real;
-  int64_t ldc_imag;
-
-  int64_t ldd_real;
-  int64_t ldd_imag;
+  int64_t lda_real{0};
+  int64_t lda_imag{0};
+  int64_t ldb_real{0};
+  int64_t ldb_imag{0};
+  int64_t ldc_real{0};
+  int64_t ldc_imag{0};
+  int64_t ldd_real{0};
+  int64_t ldd_imag{0};
};

/// Arguments for planar complex GEMMs
struct GemmPlanarComplexArrayArguments {

-  int const *M;
-  int const *N;
-  int const *K;
+  int const *M{nullptr};
+  int const *N{nullptr};
+  int const *K{nullptr};

-  void const * const * A_real;
-  void const * const * A_imag;
-  void const * const * B_real;
-  void const * const * B_imag;
-  void const * const * C_real;
-  void const * const * C_imag;
-  void * const * D_real;
-  void * const * D_imag;
+  void const * const * A_real{nullptr};
+  void const * const * A_imag{nullptr};
+  void const * const * B_real{nullptr};
+  void const * const * B_imag{nullptr};
+  void const * const * C_real{nullptr};
+  void const * const * C_imag{nullptr};
+  void * const * D_real{nullptr};
+  void * const * D_imag{nullptr};

-  void const * alpha;
-  void const * beta;
-  ScalarPointerMode pointer_mode;
+  void const * alpha{nullptr};
+  void const * beta{nullptr};
+  ScalarPointerMode pointer_mode{};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -403,29 +389,27 @@ struct GemmPlanarComplexArrayArguments {

// GemmKind: Grouped

struct GemmGroupedConfiguration {

-  int problem_count;
-  int threadblock_count;
-
+  int problem_count{0};
+  int threadblock_count{0};
};

struct GemmGroupedArguments {

-  gemm::GemmCoord *problem_sizes;
+  gemm::GemmCoord *problem_sizes{nullptr};

-  void * ptr_A;
-  void * ptr_B;
-  void * ptr_C;
-  void * ptr_D;
+  void * ptr_A{nullptr};
+  void * ptr_B{nullptr};
+  void * ptr_C{nullptr};
+  void * ptr_D{nullptr};

-  int64_t *lda;
-  int64_t *ldb;
-  int64_t *ldc;
-  int64_t *ldd;
+  int64_t *lda{nullptr};
+  int64_t *ldb{nullptr};
+  int64_t *ldc{nullptr};
+  int64_t *ldd{nullptr};

-  void const *alpha;
-  void const *beta;
-  ScalarPointerMode pointer_mode;
+  void const *alpha{nullptr};
+  void const *beta{nullptr};
+  ScalarPointerMode pointer_mode{};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -436,35 +420,31 @@ struct GemmGroupedArguments {

/// Computes GEMM assuming one of the inputs has 2:4 structured sparsity.
struct SparseGemmConfiguration {

-  GemmUniversalMode mode;
-  gemm::GemmCoord problem_size;
-  int batch_count;            /// number of sparse matrix products in batch
-
-  int64_t lda;                /// leading dimension of A operand
-  int64_t ldb;                /// leading dimension of B operand
-  int64_t ldc;                /// leading dimension of C operand
-  int64_t ldd;                /// leading dimension of D operand
-  int64_t lde;                /// leading dimension of E operand (metadata matrix)
-
-  int64_t batch_stride_A;     // stride between matrices
-  int64_t batch_stride_B;     // stride between matrices
-  int64_t batch_stride_C;     // stride between matrices
-  int64_t batch_stride_D;     // stride between matrices
-  int64_t batch_stride_E;     // stride between matrices
+  GemmUniversalMode mode{GemmUniversalMode::kGemm};
+  gemm::GemmCoord problem_size{};
+  int batch_count{1};         /// number of sparse matrix products in batch
+  int64_t lda{0};             /// leading dimension of A operand
+  int64_t ldb{0};             /// leading dimension of B operand
+  int64_t ldc{0};             /// leading dimension of C operand
+  int64_t ldd{0};             /// leading dimension of D operand
+  int64_t lde{0};             /// leading dimension of E operand (metadata matrix)
+  int64_t batch_stride_A{0};  // stride between matrices
+  int64_t batch_stride_B{0};  // stride between matrices
+  int64_t batch_stride_C{0};  // stride between matrices
+  int64_t batch_stride_D{0};  // stride between matrices
+  int64_t batch_stride_E{0};  // stride between matrices
};

/// Arguments for sparse GEMMs
struct SparseGemmArguments {

-  void const *A;                    /// pointer to A matrix
-  void const *B;                    /// pointer to B matrix
-  void const *C;                    /// pointer to C matrix
-  void *D;                          /// pointer to D matrix
-  void const *E;                    /// pointer to E matrix (metadata)
-
-  void const *alpha;                /// pointer to alpha scalar
-  void const *beta;                 /// pointer to beta scalar
-  ScalarPointerMode pointer_mode;   /// enumerant indicating whether alpha/beta pointers are host
+  void const *A{nullptr};           /// pointer to A matrix
+  void const *B{nullptr};           /// pointer to B matrix
+  void const *C{nullptr};           /// pointer to C matrix
+  void *D{nullptr};                 /// pointer to D matrix
+  void const *E{nullptr};           /// pointer to E matrix (metadata)
+  void const *alpha{nullptr};       /// pointer to alpha scalar
+  void const *beta{nullptr};        /// pointer to beta scalar
+  ScalarPointerMode pointer_mode{}; /// enumerant indicating whether alpha/beta pointers are host
                                    /// or device pointers.
};
@@ -478,52 +458,52 @@ struct SparseGemmArguments {

struct RankKConfiguration {

  /// SYRK problem size
-  gemm::GemmCoord problem_size;
+  gemm::GemmCoord problem_size{};

  /// Leading dimension of A matrix
-  int64_t lda;
+  int64_t lda{0};

  /// Leading dimension of B matrix
-  int64_t ldb;
+  int64_t ldb{0};

  /// Leading dimension of C matrix
-  int64_t ldc;
+  int64_t ldc{0};

  /// Leading dimension of D matrix
-  int64_t ldd;
+  int64_t ldd{0};

  /// Batch Count
-  int batch_count;
+  int batch_count{1};
};

/// Arguments for (Syrk, Herk, Syr2k, Her2k)
struct RankKArguments {

  /// Pointer to A matrix
-  void const *A;
+  void const *A{nullptr};

  /// Pointer to B matrix (used only for Syr2k and Her2k)
-  void const *B;
+  void const *B{nullptr};

  /// Pointer to C matrix
-  void const *C;
+  void const *C{nullptr};

  /// Pointer to D matrix
-  void *D;
+  void *D{nullptr};

  /// Host or device pointer to alpha scalar
-  void const *alpha;
+  void const *alpha{nullptr};

  /// Host or device pointer to beta scalar
-  void const *beta;
+  void const *beta{nullptr};

  /// Enumerant indicating whether alpha/beta point to host or device memory
-  ScalarPointerMode pointer_mode;
+  ScalarPointerMode pointer_mode{};

-  int64_t batch_stride_A;
-  int64_t batch_stride_B;
-  int64_t batch_stride_C;
-  int64_t batch_stride_D;
+  int64_t batch_stride_A{0};
+  int64_t batch_stride_B{0};
+  int64_t batch_stride_C{0};
+  int64_t batch_stride_D{0};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -536,45 +516,45 @@ struct RankKArguments {

struct TrmmConfiguration {

  /// TRMM problem size
-  gemm::GemmCoord problem_size;
+  gemm::GemmCoord problem_size{};

  /// Leading dimension of A matrix
-  int64_t lda;
+  int64_t lda{0};

  /// Leading dimension of B matrix
-  int64_t ldb;
+  int64_t ldb{0};

  /// Leading dimension of D matrix
-  int64_t ldd;
+  int64_t ldd{0};

  /// Batch Count
-  int batch_count;
+  int batch_count{1};
};

/// Arguments for TRMM
struct TrmmArguments {

  /// Pointer to A matrix
-  void const *A;
+  void const *A{nullptr};

  /// Pointer to B matrix
-  void const *B;
+  void const *B{nullptr};

  /// Pointer to D matrix
-  void *D;
+  void *D{nullptr};

  /// Host or device pointer to alpha scalar
-  void const *alpha;
+  void const *alpha{nullptr};

  /// Host or device pointer to beta scalar
-  void const *beta;
+  void const *beta{nullptr};

  /// Enumerant indicating whether alpha/beta point to host or device memory
-  ScalarPointerMode pointer_mode;
+  ScalarPointerMode pointer_mode{};

-  int64_t batch_stride_A;
-  int64_t batch_stride_B;
-  int64_t batch_stride_D;
+  int64_t batch_stride_A{0};
+  int64_t batch_stride_B{0};
+  int64_t batch_stride_D{0};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -587,52 +567,52 @@ struct TrmmArguments {

struct SymmConfiguration {

  /// SYMM/HEMM problem size
-  gemm::GemmCoord problem_size;
+  gemm::GemmCoord problem_size{};

  /// Leading dimension of A matrix
-  int64_t lda;
+  int64_t lda{0};

  /// Leading dimension of B matrix
-  int64_t ldb;
+  int64_t ldb{0};

  /// Leading dimension of C matrix
-  int64_t ldc;
+  int64_t ldc{0};

  /// Leading dimension of D matrix
-  int64_t ldd;
+  int64_t ldd{0};

  /// Batch Count
-  int batch_count;
+  int batch_count{1};
};

/// Arguments for (Symm, Hemm)
struct SymmArguments {

  /// Pointer to A matrix
-  void const *A;
+  void const *A{nullptr};

  /// Pointer to B matrix
-  void const *B;
+  void const *B{nullptr};

  /// Pointer to C matrix
-  void const *C;
+  void const *C{nullptr};

  /// Pointer to D matrix
-  void *D;
+  void *D{nullptr};

  /// Host or device pointer to alpha scalar
-  void const *alpha;
+  void const *alpha{nullptr};

  /// Host or device pointer to beta scalar
-  void const *beta;
+  void const *beta{nullptr};

  /// Enumerant indicating whether alpha/beta point to host or device memory
-  ScalarPointerMode pointer_mode;
+  ScalarPointerMode pointer_mode{};

-  int64_t batch_stride_A;
-  int64_t batch_stride_B;
-  int64_t batch_stride_C;
-  int64_t batch_stride_D;
+  int64_t batch_stride_A{0};
+  int64_t batch_stride_B{0};
+  int64_t batch_stride_C{0};
+  int64_t batch_stride_D{0};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -649,16 +629,16 @@ struct Conv2dConfiguration {

  /// Conv2d problem size
  // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode)
  // also includes (split_k_slices, groups)
-  conv::Conv2dProblemSize problem_size;
+  conv::Conv2dProblemSize problem_size{};

  // stride of operand A
-  std::vector<int64_t> stride_a;
+  std::vector<int64_t> stride_a{};

  // stride of operand B
-  std::vector<int64_t> stride_b;
+  std::vector<int64_t> stride_b{};

  // stride of operand C
-  std::vector<int64_t> stride_c;
+  std::vector<int64_t> stride_c{};
};

@@ -668,24 +648,24 @@ struct Conv2dConfiguration {

//
struct Conv3dConfiguration {

-  conv::SplitKMode split_k_mode;
+  conv::SplitKMode split_k_mode{};

  /// Conv2d problem size
  // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode)
  // also includes (split_k_slices, groups)
-  conv::Conv3dProblemSize problem_size;
+  conv::Conv3dProblemSize problem_size{};

  /// Layout object for activations tensor
-  layout::TensorNDHWC layout_activations;
+  layout::TensorNDHWC layout_activations{};

  /// Layout object for filters tensor
-  layout::TensorNDHWC layout_filters;
+  layout::TensorNDHWC layout_filters{};

  /// Layout object for source tensor
-  layout::TensorNDHWC layout_source;
+  layout::TensorNDHWC layout_source{};

  /// Layout object for output tensor
-  layout::TensorNDHWC layout_output;
+  layout::TensorNDHWC layout_output{};

  //
  // Methods
@@ -727,29 +707,28 @@ struct ConvArguments {

  /// ImplicitGemm matrices A, B, C, D
  /////////////////////////////////////////
  /// pointer to implicit gemm matrix A
-  void const *A;
+  void const *A{nullptr};

  /// pointer to implicit gemm matrix B
-  void const *B;
+  void const *B{nullptr};

  /// pointer to reordered matrix B
-  void const *reordered_B;
+  void const *reordered_B{nullptr};

  /// pointer to implicit gemm matrix C
-  void const *C;
+  void const *C{nullptr};

  /// pointer to implicit gemm destination matrix D
-  void *D;
+  void *D{nullptr};

  /// Host or device pointer to alpha scalar
-  void const *alpha;
+  void const *alpha{nullptr};

  /// Host or device pointer to beta scalar
-  void const *beta;
+  void const *beta{nullptr};

  /// Enumerant indicating whether alpha/beta point to host or device memory
-  ScalarPointerMode pointer_mode;
-
+  ScalarPointerMode pointer_mode{};
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -761,47 +740,47 @@ struct ConvArguments {

struct ReductionConfiguration {

  /// Reduction problem size
-  MatrixCoord problem_size;
+  MatrixCoord problem_size{};

  /// Number of partitions to reduce
-  int partitions;
+  int partitions{0};

  /// Number of elements between each partition
-  int64_t partition_stride;
+  int64_t partition_stride{0};

  /// leading dimension of 'w'orkspace operand
-  int64_t ldw;
+  int64_t ldw{0};

  /// leading dimension of 's'ource operand
-  int64_t lds;
+  int64_t lds{0};

  /// leading dimension of 'd'estination operand
-  int64_t ldd;
+  int64_t ldd{0};
};

/// Arguments for Reduction
struct ReductionArguments {

  /// Pointer to workspace matrix
-  void const *workspace;
+  void const *workspace{nullptr};

  /// Pointer to source matrix
-  void const *source;
+  void const *source{nullptr};

  /// Pointer to destination matrix
-  void *destination;
+  void *destination{nullptr};

  /// pointer to reference matrix
-  void *reference;
+  void *reference{nullptr};

  /// Host or device pointer to alpha scalar
-  void const *alpha;
+  void const *alpha{nullptr};

  /// Host or device pointer to beta scalar
-  void const *beta;
+  void const *beta{nullptr};

  /// Enumerant indicating whether alpha/beta point to host or device memory
-  ScalarPointerMode pointer_mode;
+  ScalarPointerMode pointer_mode{};
};

} // namespace library
tools/library/src/conv_operation_3x.hpp (new file, 865 lines)
@@ -0,0 +1,865 @@
/***************************************************************************************************
 * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/* \file
   \brief Defines operations for all CONV operation kinds in CUTLASS Library.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "library_internal.h"
#include "cutlass/conv/convnd_problem_shape.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/detail/dependent_false.hpp"
#include "cutlass/trace.h"
#include <utility>
#include <variant>
#if defined(CUTLASS_DEBUG_TRACE_LEVEL)
#include <sstream>
#endif

namespace cutlass::library {

namespace detail {

template<class ValueType, size_t ... Indices>
constexpr cute::array<ValueType, 1u + sizeof...(Indices)>
vector_to_array_strides_helper(const std::vector<ValueType>& v,
                               std::index_sequence<Indices...>)
{
  return {v[(sizeof...(Indices) - 1u) - Indices]..., ValueType(1)};
}

template<class ValueType, size_t Size>
cute::array<ValueType, Size>
vector_to_array_strides(const std::vector<ValueType>& v, std::integral_constant<size_t, Size>)
{
  static_assert(Size != 0);
  CUTLASS_ASSERT(v.size() + 1u == Size);
  return vector_to_array_strides_helper(v, std::make_index_sequence<Size - 1u>{});
}

template<class Index, class LongIndex, size_t ... Indices>
constexpr cute::array<int64_t, 1u + sizeof...(Indices)>
coord_to_array_strides_helper(
  const ::cutlass::Coord<int(sizeof...(Indices)), Index, LongIndex> coord,
  std::index_sequence<Indices...>)
{
  return {int64_t(coord[(sizeof...(Indices) - 1u) - Indices])..., int64_t(1)};
}

template<int Rank, class Index, class LongIndex>
cute::array<int64_t, 1u + size_t(Rank)>
coord_to_array_strides(const ::cutlass::Coord<Rank, Index, LongIndex>& coord)
{
  static_assert(Rank >= 0);
  return coord_to_array_strides_helper(coord, std::make_index_sequence<Rank>{});
}

} // namespace detail

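Note: both helpers reverse their input and append a trailing unit stride, producing the fixed-size cute::array the kernel-facing code below consumes. A hedged usage sketch with invented values:

// Sketch only, values invented: a length-3 vector {8, 4, 2} becomes
// the length-4 array {2, 4, 8, 1} -- the input reversed, plus a unit stride.
std::vector<int64_t> v{8, 4, 2};
auto strides = detail::vector_to_array_strides(
    v, std::integral_constant<size_t, 4>{});
// strides[0] == 2, strides[1] == 4, strides[2] == 8, strides[3] == 1
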
// Tells the profiler about CUTLASS 3's 2-D and 3-D convolutions.
// For CUTLASS 2's 2-D convolutions, see Conv2dOperation.
// For CUTLASS 2's 3-D convolutions, see Conv3dOperation.
template<class Operator_>
class ConvOperation3x : public Operation {
public:
  using Operator = Operator_;

  static_assert(Operator::NumSpatialDimensions == 2 ||
    Operator::NumSpatialDimensions == 3,
    "The profiler currently only supports convolutions with 2 or 3 spatial dimensions.");
  using LayoutA = cute::conditional_t<Operator::NumSpatialDimensions == 3,
    cutlass::layout::TensorNDHWC,
    cute::conditional_t<Operator::NumSpatialDimensions == 2,
      cutlass::layout::TensorNHWC,
      cutlass::layout::TensorNWC>
    >;
  using LayoutB = LayoutA;
  using LayoutC = LayoutA;

  using ElementA = typename Operator::ElementA;
  using ElementB = typename Operator::ElementB;
  using ElementC = typename Operator::ElementC;
  using ElementD = typename Operator::ElementD;
  using ElementAccumulator = typename Operator::ElementAccumulator;
  using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
  static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator;

  ConvOperation3x(const char* name = "unknown_cutlass_3_conv") {
    // Initialize OperationDescription (the base class)
    description_.name = name;
    description_.provider = Provider::kCUTLASS;

    if constexpr (Operator::NumSpatialDimensions == 2) {
      description_.kind = OperationKind::kConv2d;
    }
    else if constexpr (Operator::NumSpatialDimensions == 3) {
      description_.kind = OperationKind::kConv3d;
    }
    else {
      static_assert(::cutlass::detail::dependent_false<Operator>,
        "This class currently only supports 2-D and 3-D convolutions.");
    }

    description_.tile_description.threadblock_shape = make_Coord(
      Operator::ThreadblockShape::kM,
      Operator::ThreadblockShape::kN,
      Operator::ThreadblockShape::kK);

    description_.tile_description.threadblock_stages = Operator::kStages;

    description_.tile_description.warp_count = make_Coord(
      Operator::WarpCount::kM,
      Operator::WarpCount::kN,
      Operator::WarpCount::kK);

    description_.tile_description.math_instruction.instruction_shape = make_Coord(
      Operator::InstructionShape::kM,
      Operator::InstructionShape::kN,
      Operator::InstructionShape::kK);

    description_.tile_description.math_instruction.element_accumulator =
      NumericTypeMap<ElementAccumulator>::kId;

    description_.tile_description.math_instruction.opcode_class =
      OpcodeClassMap<typename Operator::OperatorClass>::kId;

    description_.tile_description.math_instruction.math_operation =
      MathOperationID::kMultiplyAdd;

    description_.tile_description.minimum_compute_capability =
      ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMin;

    description_.tile_description.maximum_compute_capability =
      ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMax;

    // Initialize ConvDescription (the subclass)

    // kConvDim does not exist in Operator for CUTLASS 3 convolutions.
    // For CUTLASS 2 convolutions, it is the number of spatial dimensions.
    description_.conv_dim = Operator::NumSpatialDimensions;
    description_.conv_kind = ConvKindMap<kConvolutionalOperator>::kId;

    description_.iterator_algorithm = {};

    description_.A = make_TensorDescription<ElementA, LayoutA>();
    description_.B = make_TensorDescription<ElementB, LayoutB>();
    description_.C = make_TensorDescription<ElementC, LayoutC>();
    description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
  }

  ~ConvOperation3x() override = default;

  OperationDescription const& description() const override {
    return static_cast<OperationDescription const&>(description_);
  }

private:
  Status update_operator_arguments_from_configuration_2d_or_3d(
    typename Operator::Arguments& out_args,
    void const* configuration) const {
    Status status = Status::kInvalid;

    CUTLASS_ASSERT(configuration != nullptr);

    if constexpr (Operator::NumSpatialDimensions == 2) {
      CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d);
      // tools/library/include/cutlass/library/library.h
      // defines Conv2dConfiguration.
      // tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h
      // uses Conv2dConfiguration.
      auto* conf_ptr = reinterpret_cast<Conv2dConfiguration const*>(configuration);
      status = update_operator_arguments_from_configuration(out_args, *conf_ptr);
    }
    else if constexpr (Operator::NumSpatialDimensions == 3) {
      CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d);
      auto* conf_ptr = reinterpret_cast<Conv3dConfiguration const*>(configuration);
      status = update_operator_arguments_from_configuration(out_args, *conf_ptr);
    }
    else {
      static_assert(::cutlass::detail::dependent_false<Operator>,
        "This class currently only supports 2-D and 3-D convolutions.");
    }

    return status;
  }

public:
  Status can_implement(
    void const* configuration,
    void const* arguments) const override {
    Status status = Status::kInvalid;

    // gemm_operation_3x.hpp accesses "configuration" as
    // GemmUniversalConfiguration (which lives in
    // tools/library/include/cutlass/library/library.h) and
    // "arguments" as GemmUniversalArguments (which lives in
    // tools/library/include/cutlass/library/library.h).
    // Those things don't apply to convolutions.
    // Despite the existence of ConvUniversal, there's no
    // corresponding "ConvUniversalConfiguration" or
    // "ConvUniversalArguments."

    CUTLASS_ASSERT(configuration != nullptr);
    CUTLASS_ASSERT(arguments != nullptr);

    typename Operator::Arguments out_args{};
    status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration);
    if (status != Status::kSuccess) {
      return status;
    }

    auto* in_args_ptr = reinterpret_cast<ConvArguments const*>(arguments);
    status = update_operator_arguments_from_arguments(out_args, *in_args_ptr);
    if (status != Status::kSuccess) {
      return status;
    }

    return Operator::can_implement(out_args);
  }

  uint64_t get_host_workspace_size(void const* /* configuration */) const override {
    return sizeof(Operator);
  }

  uint64_t get_device_workspace_size(
    void const* configuration,
    void const* arguments = nullptr) const override
  {
    // This presumes that at least one of configuration or arguments is nonnull.
    Status status = Status::kInvalid;

    // gemm_operation_3x.hpp has get_device_workspace_size return 0 on
    // error.  It's not clear that this is what we want -- perhaps we
    // should return something like expected<uint64_t, Status>? -- but
    // it's the only option that preserves the current interface.
    constexpr uint64_t error_indication = 0;

    typename Operator::Arguments out_args{};
    if (configuration != nullptr) {
      status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration);
      if (status != Status::kSuccess) {
        return error_indication;
      }
    }
    if (arguments != nullptr) {
      auto* in_args_ptr = reinterpret_cast<ConvArguments const*>(arguments);
      status = update_operator_arguments_from_arguments(out_args, *in_args_ptr);
      if (status != Status::kSuccess) {
        return error_indication;
      }
    }

    if (status == Status::kSuccess) {
      return static_cast<uint64_t>(Operator::get_workspace_size(out_args));
    }
    else {
      return error_indication;
    }
  }

  Status initialize(
    void const* configuration,
    void* host_workspace,
    void* /* device_workspace */ = nullptr,
    cudaStream_t stream = nullptr) const override
  {
    Status status = Status::kInvalid;

    if (configuration == nullptr) {
      CUTLASS_TRACE_HOST("Input configuration is null.");
      return Status::kInvalid;
    }

    typename Operator::Arguments out_args{};
    status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration);
    if (status != Status::kSuccess) {
      // Any kind of failure invalidates the last successful configuration.
      clear_last_successful_config();
      return status;
    }
    else {
      set_last_successful_config(configuration);
    }

    if (host_workspace == nullptr) {
      CUTLASS_TRACE_HOST("host_workspace is null.");
      return Status::kInvalid;
    }
    (void) new (host_workspace) Operator;
    return status;

    // CUTLASS 2 convolutions call the Operator's initialize function
    // here, like this.
    //
    //return op->initialize(args, device_workspace, stream);
    //
    // CUTLASS 3 convolutions (ConvUniversal), like CUTLASS 3 Gemms
    // (GemmUniversal), lack an "initialize" member function.
  }

  Status run(
    void const* arguments,
    void* host_workspace,
    void* device_workspace = nullptr,
    cudaStream_t stream = nullptr) const override
  {
    auto status = Status::kInvalid;

    // The Operator doesn't appear to save the last configuration (it
    // doesn't have a way to do that, since it lacks an initialize()
    // member function), so we have to use the stored configuration
    // from the last successful initialize() call (if any).
    typename Operator::Arguments out_args{};
    status = update_operator_arguments_from_stored_configuration(out_args);
    if (status != Status::kSuccess) {
      CUTLASS_TRACE_HOST("Updating from previous successful configuration failed.");
      return status;
    }

    if (arguments == nullptr) {
      CUTLASS_TRACE_HOST("Input argument 'arguments' is null.");
      return Status::kInvalid;
    }
    auto* in_args_ptr = reinterpret_cast<ConvArguments const*>(arguments);
    status = update_operator_arguments_from_arguments(out_args, *in_args_ptr);
    if (status != Status::kSuccess) {
      return status;
    }

    auto* op = reinterpret_cast<Operator*>(host_workspace);
    return op->run(out_args, device_workspace, stream);
  }

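Note: together these overrides implement the Operation life cycle the profiler drives: check feasibility, size the workspaces, placement-construct the Operator in the host workspace, then launch. A hedged sketch of the caller-side sequence (op, conf, args, and stream are assumed to be set up elsewhere; error handling elided):

// Hypothetical caller-side sequence; only the member functions are library API.
std::vector<char> host_ws(op->get_host_workspace_size(conf));
void* device_ws = nullptr;
cudaMalloc(&device_ws, op->get_device_workspace_size(conf, args));
op->can_implement(conf, args);                     // feasibility check
op->initialize(conf, host_ws.data(), device_ws);   // placement-news the Operator
op->run(args, host_ws.data(), device_ws, stream);  // launches the kernel
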
private:
  ConvDescription description_;
  // Result of initialize() calling
  // update_operator_arguments_from_configuration() successfully.
  // This is needed because run() doesn't take a configuration, just
  // arguments, and the kernel doesn't appear to save the
  // configuration from the last initialize() call.
  //
  // Unfortunately, this must be declared mutable, because it must be
  // set in initialize(), and initialize() is inherited as const.
  mutable std::variant<
    std::monostate,
    Conv2dConfiguration,
    Conv3dConfiguration> last_successful_config_{std::monostate{}};

  // Clear the last configuration resulting from a successful initialize() call.
  //
  // Unfortunately, this must be declared const, because initialize() is.
  void clear_last_successful_config() const {
    last_successful_config_ = std::monostate{};
  }

  // Set the last configuration resulting from a successful initialize() call.
  //
  // Unfortunately, this must be declared const, because initialize() is.
  void set_last_successful_config(void const* configuration) const {
    CUTLASS_ASSERT(configuration != nullptr);

    if constexpr (Operator::NumSpatialDimensions == 2) {
      CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d);
      auto* conf_ptr = reinterpret_cast<Conv2dConfiguration const*>(configuration);
      last_successful_config_ = *conf_ptr;
    } else if constexpr (Operator::NumSpatialDimensions == 3) {
      CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d);
      auto* conf_ptr = reinterpret_cast<Conv3dConfiguration const*>(configuration);
      last_successful_config_ = *conf_ptr;
    }
    else {
      static_assert(::cutlass::detail::dependent_false<Operator>,
        "This class currently only supports 2-D and 3-D convolutions.");
    }
  }

  // Whether a configuration from a successful initialize() call exists.
  bool last_successful_config_exists() const {
    return not std::holds_alternative<std::monostate>(last_successful_config_);
  }

  // Visitor for update_operator_arguments_from_stored_configuration.
  struct ConfigurationVisitor {
    typename Operator::Arguments& out_args;

    Status operator() (std::monostate const&) const {
      CUTLASS_TRACE_HOST("No successful previous configuration exists. "
        "One cause is calling run() before a successful initialize() call.");
      return Status::kInvalid;
    }
    Status operator() (Conv2dConfiguration const& conf2d) const {
      return update_operator_arguments_from_configuration(out_args, conf2d);
    }
    Status operator() (Conv3dConfiguration const& conf3d) const {
      return update_operator_arguments_from_configuration(out_args, conf3d);
    }
  };

  // Like update_operator_arguments_from_configuration, but on the
  // stored configuration from the last successful initialize() call,
  // if any.  If there was no last successful initialize() call,
  // then return Status::kInvalid.
  //
  // Unfortunately, this must be declared const, because run() is.
  Status update_operator_arguments_from_stored_configuration(
    typename Operator::Arguments& out_args) const
  {
    return std::visit(ConfigurationVisitor{out_args}, last_successful_config_);
  }

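Note: the std::variant plus visitor combination above is the standard type-safe way to dispatch on whichever configuration (if any) was stored last. A stripped-down, self-contained sketch of the same idiom (all names invented):

#include <iostream>
#include <variant>

struct Conf2d { int n; };
struct Conf3d { int n; };
using Stored = std::variant<std::monostate, Conf2d, Conf3d>;

struct Visitor {
  void operator()(std::monostate) const { std::cout << "nothing stored\n"; }
  void operator()(Conf2d const& c) const { std::cout << "2-D: " << c.n << "\n"; }
  void operator()(Conf3d const& c) const { std::cout << "3-D: " << c.n << "\n"; }
};

int main() {
  Stored s = Conf2d{2};
  std::visit(Visitor{}, s);  // prints "2-D: 2"
}
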
  template<class FusionArgs, class = void>
  struct UpdateFusionArgs {
    static Status update_(
      FusionArgs const&,
      ConvArguments const&)
    {
      // For custom EVT, it is the user's responsibility to ensure
      // that alpha and beta are updated appropriately.
      return Status::kSuccess;
    }
  };

  template<class FusionArgs>
  struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
    static Status update_(
      FusionArgs& fusion_args,
      ConvArguments const& arguments)
    {
      if (arguments.pointer_mode == ScalarPointerMode::kHost) {
        fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
        fusion_args.beta = *static_cast<ElementCompute const *>(arguments.beta);
        fusion_args.alpha_ptr = nullptr;
        fusion_args.beta_ptr = nullptr;

        return Status::kSuccess;
      }
      else if (arguments.pointer_mode == ScalarPointerMode::kDevice) {
        fusion_args.alpha = 0;
        fusion_args.beta = 0;
        fusion_args.alpha_ptr = static_cast<ElementCompute const *>(arguments.alpha);
        fusion_args.beta_ptr = static_cast<ElementCompute const *>(arguments.beta);

        return Status::kSuccess;
      }
      else {
        return Status::kErrorInvalidProblem;
      }
    }
  };

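Note: the partial specialization above uses the void_t member-detection idiom: it is selected only when FusionArgs{}.alpha is well-formed, so epilogues without alpha/beta fall through to the no-op primary template. A minimal standalone sketch of the idiom, with std::void_t playing the role of cute::void_t (types invented):

#include <type_traits>

template <class T, class = void>
struct HasAlpha : std::false_type {};

template <class T>
struct HasAlpha<T, std::void_t<decltype(T{}.alpha)>> : std::true_type {};

struct WithAlpha    { float alpha; };
struct WithoutAlpha { };

static_assert(HasAlpha<WithAlpha>::value);
static_assert(!HasAlpha<WithoutAlpha>::value);
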
  static Status update_operator_arguments_from_configuration(
    typename Operator::Arguments& out_args,
    Conv2dConfiguration const& config)
  {
    using detail::vector_to_array_strides;

    constexpr int num_spatial_dims = Operator::NumSpatialDimensions;
    if constexpr (num_spatial_dims != 2) {
      CUTLASS_TRACE_HOST("You can only use Conv2dConfiguration "
        "with an Operator whose NumSpatialDimensions is exactly 2.");
      return Status::kInvalid;
    }
    else {
      // Convolutions split the metadata (in Conv2dConfiguration) from
      // the data (ConvArguments, which only has pointers and a single
      // enum value).  Thus, this class will need both the
      // configuration and the (user's input) arguments to set up the
      // kernel's arguments.  This function can fill in what the
      // configuration has now, but the class will need the user's
      // input arguments later.
      if (config.split_k_mode != conv::SplitKMode::kSerial) {
        CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial.");
        return Status::kInvalid;
      }
      // config.problem_size.split_k_slices is only meaningful if
      // split_k_mode != kSerial.  If this code later supports other
      // split_k_mode values, then it will also need to read
      // split_k_slices.

      const int N = config.problem_size.N;
      const int H = config.problem_size.H;
      const int W = config.problem_size.W;
      const int C = config.problem_size.C;
      const int K = config.problem_size.K;
      const int R = config.problem_size.R;
      const int S = config.problem_size.S;
      const int pad_h = config.problem_size.pad_h;
      const int pad_w = config.problem_size.pad_w;
      const int traversal_stride_h = config.problem_size.stride_h;
      const int traversal_stride_w = config.problem_size.stride_w;
      const int dilation_h = config.problem_size.dilation_h;
      const int dilation_w = config.problem_size.dilation_w;

      // CUTLASS 3's implicit GEMM convolution kernels currently only
      // support cross correlation (passing over the activation and
      // filter tensors in the same order).  The convolution mode is
      // future work.
      const auto mode = config.problem_size.mode;
      if (mode != cutlass::conv::Mode::kCrossCorrelation) {
        CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation "
          "are not currently supported.");
        return Status::kInvalid;
      }

      constexpr int num_spatial_dims = Operator::NumSpatialDimensions;
      constexpr size_t stride_size = size_t(num_spatial_dims) + 2u;
      constexpr auto the_stride_size = std::integral_constant<size_t, stride_size>{};

#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
      std::cerr << "  num_spatial_dims = " << num_spatial_dims << "\n"
                << "  stride_size = " << stride_size << "\n";
      auto print_stride = [] (auto const& stride, char const variable_name[]) {
        std::cerr << "  " << variable_name << ": [";
        for (size_t k = 0; k < stride.size(); ++k) {
          std::cerr << stride[k];
          if (k + 1u < stride.size()) {
            std::cerr << ", ";
          }
        }
        std::cerr << "]\n";
      };
      print_stride(config.stride_a, "config.stride_a");
      print_stride(config.stride_b, "config.stride_b");
      print_stride(config.stride_c, "config.stride_c");
#endif

      // Conv2dConfiguration stores the strides as std::vector,
      // so the code needs to check the run-time vector lengths.
      if (config.stride_a.size() + 1u != stride_size) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL)
        std::ostringstream os;
        os << "config.stride_a.size() + 1u = "
           << (config.stride_a.size() + 1u)
           << " != num_spatial_dims + 2u = " << stride_size;
        CUTLASS_TRACE_HOST( os.str() );
#endif
        return Status::kInvalid;
      }
      if (config.stride_b.size() + 1u != stride_size) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL)
        std::ostringstream os;
        os << "config.stride_b.size() + 1u = "
           << (config.stride_b.size() + 1u)
           << " != num_spatial_dims + 2u = " << stride_size;
        CUTLASS_TRACE_HOST( os.str() );
#endif
        return Status::kInvalid;
      }
      if (config.stride_c.size() + 1u != stride_size) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL)
        std::ostringstream os;
        os << "config.stride_c.size() + 1u = "
           << (config.stride_c.size() + 1u)
           << " != num_spatial_dims + 2u = " << stride_size;
        CUTLASS_TRACE_HOST( os.str() );
#endif
        return Status::kInvalid;
      }

      constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp;
      using problem_shape_type =
        cutlass::conv::ConvProblemShape<conv_op, num_spatial_dims>;
      // cute::array<int64_t, RankT>; must convert to the kernel's native strides
      using TensorStride = typename problem_shape_type::TensorStride;

      const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size);
      const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size);
      const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size);

      // cutlass::library::Conv2dConfiguration has no member stride_d.
      // The code below imitates the testbed,
      // which just sets D's strides to C's strides.
      const TensorStride stride_D = stride_C;

      const int num_groups = config.problem_size.groups;
      if (num_groups != 1) {
        CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1.");
        return Status::kInvalid;
      }
      problem_shape_type problem_shape(
        /* mode = */ mode,
        /* shape_act = */ {N, H, W, C},
        /* stride_act = */ stride_A,
        /* shape_flt = */ {K, R, S, C},
        /* stride_flt = */ stride_B,
        /* lower_padding = */ {pad_h, pad_w},
        /* upper_padding = */ {pad_h, pad_w},
        /* traversal_stride = */ {traversal_stride_h, traversal_stride_w},
        /* dilation = */ {dilation_h, dilation_w},
        num_groups);
      out_args.mainloop.problem_shape = problem_shape;

      // ConvProblemShape's constructor sets its shape_C member.
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
      std::cerr << "  problem_shape:\n"
                << "    shape_C: " << problem_shape.shape_C << "\n";
      std::cerr << "    stride_C: " << problem_shape.stride_C << "\n";
#endif
      // Initialization of C's and D's strides follows the CUTLASS 3
      // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp).
      {
        using StrideC = typename Operator::ConvKernel::StrideC;
        using StrideD = typename Operator::ConvKernel::StrideD;
        auto stride_C = StrideC{};
        auto stride_D = StrideD{};

        if constexpr (conv_op == cutlass::conv::Operator::kWgrad) {
          stride_C = cutlass::make_cute_packed_stride(
            StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op);
          stride_D = cutlass::make_cute_packed_stride(
            StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op);
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
          std::cerr << "  Wgrad: stride_C: " << stride_C << "\n";
#endif
        }
        else {
          cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
            const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i];
            std::cerr << "  Fprop or Dgrad: get<0, " << i << ">(stride_C): "
                      << stride_C_i << "\n";
#endif
            cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i];
          });
          cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
            const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i];
            std::cerr << "  Fprop or Dgrad: get<0, " << i << ">(stride_D): "
                      << stride_D_i << "\n";
#endif
            cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i];
          });
        }
        out_args.epilogue.dC = stride_C;
        out_args.epilogue.dD = stride_D;
      }
      return Status::kSuccess;
    }
  }

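Note: the Fprop/Dgrad branch above copies the trailing problem-shape strides into the kernel's CuTe stride type with a compile-time loop (cute::for_each over cute::make_seq). A self-contained sketch of the same idiom in plain C++17, with invented array contents:

#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

// Visit the indices 0..N-1 at compile time, like cute::for_each over cute::make_seq.
template <std::size_t... Is, class F>
void for_each_index(std::index_sequence<Is...>, F&& f) {
  (f(std::integral_constant<std::size_t, Is>{}), ...);  // fold over the comma operator
}

int main() {
  std::array<int, 3> dst{};
  const std::array<int, 5> src{10, 20, 30, 40, 50};
  // Copy src[RankT - 2 - i] into dst[i], mirroring the stride copy above (RankT = 5 here).
  for_each_index(std::make_index_sequence<3>{}, [&](auto i) {
    dst[i] = src[5 - 2 - i];
  });
  // dst == {40, 30, 20}
}
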
static Status update_operator_arguments_from_configuration(
|
||||
typename Operator::Arguments& out_args,
|
||||
Conv3dConfiguration const& config)
|
||||
{
|
||||
using detail::coord_to_array_strides;
|
||||
|
||||
constexpr int num_spatial_dims = Operator::NumSpatialDimensions;
|
||||
if constexpr (num_spatial_dims != 3) {
|
||||
CUTLASS_TRACE_HOST("You can only use Conv3dConfiguration "
|
||||
"with an Operator whose NumSpatialDimensions is exactly 3.");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
else {
|
||||
// Convolutions split the metadata (in Conv3dConfiguration) from
|
||||
// the data (ConvArguments, which only has pointers and a single
|
||||
// enum value). Thus, this class will need both the
|
||||
// configuration and the (user's input) arguments to set up the
|
||||
// kernel's arguments. This function can fill in what the
|
||||
// configuration has now, but the class will need the user's
|
||||
// input arguments later.
|
||||
if (config.split_k_mode != conv::SplitKMode::kSerial) {
|
||||
CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial.");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
// config.problem_size.split_k_slices is only meaningful if
|
||||
// split_k_mode != kSerial. If this code later supports other
|
||||
// split_k_mode values, then it will also need to read
|
||||
// split_k_slices.
|
||||
|
||||
const int N = config.problem_size.N;
|
||||
const int D = config.problem_size.D;
|
||||
const int H = config.problem_size.H;
|
||||
const int W = config.problem_size.W;
|
||||
const int C = config.problem_size.C;
|
||||
const int K = config.problem_size.K;
|
||||
const int T = config.problem_size.T;
|
||||
const int R = config.problem_size.R;
|
||||
const int S = config.problem_size.S;
|
||||
const int pad_d = config.problem_size.pad_d;
|
||||
const int pad_h = config.problem_size.pad_h;
|
||||
const int pad_w = config.problem_size.pad_w;
|
||||
const int traversal_stride_d = config.problem_size.stride_d;
|
||||
const int traversal_stride_h = config.problem_size.stride_h;
|
||||
const int traversal_stride_w = config.problem_size.stride_w;
|
||||
const int dilation_d = config.problem_size.dilation_d;
|
||||
const int dilation_h = config.problem_size.dilation_h;
|
||||
const int dilation_w = config.problem_size.dilation_w;
|
||||
|
||||
// CUTLASS 3's implicit GEMM convolution kernels currently only
|
||||
// support cross correlation (passing over the activation and
|
||||
// filter tensors in the same order). The convolution mode is
|
||||
// future work.
|
||||
const auto mode = config.problem_size.mode;
if (mode != cutlass::conv::Mode::kCrossCorrelation) {
CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation "
"are not currently supported.");
return Status::kInvalid;
}

using Stride = cutlass::layout::TensorNDHWC::Stride;
static_assert(std::is_same_v<Stride, cutlass::Coord<4>>);

const cutlass::library::ConvKind conv_kind = [] () {
constexpr cutlass::conv::Operator op = Operator::DispatchPolicy::ConvOp;
if constexpr (op == cutlass::conv::Operator::kFprop) {
return library::ConvKind::kFprop;
}
else if constexpr (op == cutlass::conv::Operator::kDgrad) {
return library::ConvKind::kDgrad;
}
else /* if constexpr (op == cutlass::conv::Operator::kWgrad) */ {
return library::ConvKind::kWgrad;
}
} ();
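// The lambda above is invoked immediately so conv_kind can be const while the
// selection still happens via `if constexpr`. A hypothetical equivalent helper
// (illustrative only, not part of this file):
//
//   template <cutlass::conv::Operator op>
//   constexpr library::ConvKind to_conv_kind() {
//     if constexpr (op == cutlass::conv::Operator::kFprop) { return library::ConvKind::kFprop; }
//     else if constexpr (op == cutlass::conv::Operator::kDgrad) { return library::ConvKind::kDgrad; }
//     else { return library::ConvKind::kWgrad; }
//   }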
const Stride input_stride_a = config.layout_a(conv_kind).stride();
const Stride input_stride_b = config.layout_b(conv_kind).stride();
const Stride input_stride_c = config.layout_c(conv_kind).stride();

#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
constexpr size_t stride_size = size_t(num_spatial_dims) + 2u;
std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n"
<< " stride_size = " << stride_size << "\n";
auto print_stride = [] (Stride const& stride, char const variable_name[]) {
std::cerr << " " << variable_name << ": [";
for (size_t k = 0; k < Stride::kRank; ++k) {
std::cerr << stride[static_cast<int>(k)];
if (k + 1u < Stride::kRank) {
std::cerr << ", ";
}
}
std::cerr << "]\n";
};
print_stride(input_stride_a, "input_stride_a");
print_stride(input_stride_b, "input_stride_b");
print_stride(input_stride_c, "input_stride_c");
#endif

constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp;
using problem_shape_type =
cutlass::conv::ConvProblemShape<conv_op, num_spatial_dims>;
// cute::array<int64_t, RankT>; must convert to the kernel's native strides
using TensorStride = typename problem_shape_type::TensorStride;

const TensorStride stride_A = coord_to_array_strides(input_stride_a);
const TensorStride stride_B = coord_to_array_strides(input_stride_b);
const TensorStride stride_C = coord_to_array_strides(input_stride_c);

const TensorStride stride_D = stride_C;
const int num_groups = config.problem_size.groups;
if (num_groups != 1) {
CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1.");
return Status::kInvalid;
}
problem_shape_type problem_shape(
/* mode = */ mode,
/* shape_act = */ {N, D, H, W, C},
/* stride_act = */ stride_A,
/* shape_flt = */ {K, T, R, S, C},
/* stride_flt = */ stride_B,
/* lower_padding = */ {pad_d, pad_h, pad_w},
/* upper_padding = */ {pad_d, pad_h, pad_w},
/* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w},
/* dilation = */ {dilation_d, dilation_h, dilation_w},
num_groups);
out_args.mainloop.problem_shape = problem_shape;

// ConvProblemShape's constructor sets its shape_C member.
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
std::cerr << " problem_shape:\n"
<< " shape_C: " << problem_shape.shape_C << "\n";
std::cerr << " stride_C: " << problem_shape.stride_C << "\n";
#endif

{
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
std::cerr << " Compute stride_C and stride_D\n";
#endif
using StrideC = typename Operator::ConvKernel::StrideC;
using StrideD = typename Operator::ConvKernel::StrideD;
auto stride_C = StrideC{};
auto stride_D = StrideD{};

if constexpr (conv_op == cutlass::conv::Operator::kWgrad) {
stride_C = cutlass::make_cute_packed_stride(
StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op);
stride_D = cutlass::make_cute_packed_stride(
StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op);
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
std::cerr << " Wgrad: stride_C: " << stride_C << "\n";
#endif
}
else {
cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i];
std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): "
<< stride_C_i << "\n";
#endif
cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i];
});
cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i];
std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): "
<< stride_D_i << "\n";
#endif
cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i];
});
}
out_args.epilogue.dC = stride_C;
out_args.epilogue.dD = stride_D;
}
return Status::kSuccess;
}
}

Status update_operator_arguments_from_arguments(
typename Operator::Arguments& out_args,
ConvArguments const& in_args) const
{
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
std::cerr << "ConvOperation3x::update_operator_arguments_from_arguments\n";
#endif

out_args.mainloop.ptr_A = reinterpret_cast<ElementA const*>(in_args.A);
out_args.mainloop.ptr_B = reinterpret_cast<ElementB const*>(in_args.B);

out_args.epilogue.ptr_C = reinterpret_cast<ElementC const*>(in_args.C);
out_args.epilogue.ptr_D = reinterpret_cast<ElementD*>(in_args.D);

return Status::kSuccess;
}
};

} // namespace cutlass::library
@ -37,6 +37,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "library_internal.h"
#include "cutlass/gemm/dispatch_policy.hpp"

///////////////////////////////////////////////////////////////////////////////////////////////////

@ -63,7 +64,6 @@ public:
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;

private:

GemmDescription description_;

public:
@ -215,7 +215,8 @@ protected:

/// Constructs the arguments structure given the configuration and arguments
static Status update_arguments_(
OperatorArguments &operator_args, GemmUniversalArguments const *arguments) {
OperatorArguments &operator_args,
GemmUniversalArguments const *arguments) {
Status status = Status::kSuccess;

status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
@ -261,7 +262,7 @@ protected:
operator_args.scheduler.raster_order = Enum_t::Heuristic;
}
}


return status;
}


@ -1202,7 +1202,7 @@ std::string lexical_cast(int64_t int_value) {
/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid.
std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {

int size_bytes = sizeof_bits(type) / 8;
size_t size_bytes = sizeof_bits(type) / 8;

if (!size_bytes || size_bytes != bytes.size()) {
return "<invalid>";

@ -158,8 +158,8 @@ using DispositionMap = std::map<library::Provider, Disposition>;
// Print vector for the report
template <typename T>
std::ostream& operator<< (std::ostream& out, const std::vector<T>& v) {
for(int i = 0; i < v.size(); ++i) {
out << to_string(v[i], true) << (i+1 != v.size() ? "," : "");
for (size_t i = 0; i < v.size(); ++i) {
out << to_string(v[i], true) << (i + 1u != v.size() ? "," : "");
}
return out;
}

@ -29,7 +29,7 @@
*
**************************************************************************************************/
/* \file
\brief Defines a math function
\brief Gemm Profiler
*/

#pragma once
@ -67,23 +67,23 @@ public:
/// Problem structure obtained from problem space
struct GemmProblem {

cutlass::library::GemmUniversalMode mode;
cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm};

int64_t m;
int64_t n;
int64_t k;

int64_t lda;
int64_t ldb;
int64_t ldc;
int64_t m{16};
int64_t n{16};
int64_t k{16};

int64_t lda{0};
int64_t ldb{0};
int64_t ldc{0};
std::vector<uint8_t> alpha;
std::vector<uint8_t> beta;

cutlass::library::SplitKMode split_k_mode;
int split_k_slices;
int batch_count;
cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone};
int split_k_slices{1};
int batch_count{1};

cutlass::library::RasterOrder raster_order;
cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
// gemm with parallel interleaved reduction
// gemm epilogue (alpha, beta) = (1.0, 0.0)
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
@ -94,18 +94,6 @@ public:
// Methods
//

GemmProblem():
mode(library::GemmUniversalMode::kGemm),
m(16),
n(16),
k(16),
lda(0),
ldb(0),
ldc(0),
split_k_slices(1),
batch_count(1),
raster_order(cutlass::library::RasterOrder::kHeuristic){ }

/// Parses the problem
Status parse(
library::GemmDescription const &operation_desc,
@ -128,15 +116,15 @@ public:
/// Workspace used
struct GemmWorkspace {

DeviceAllocation *A;
DeviceAllocation *B;
DeviceAllocation *C;
DeviceAllocation *Computed;
DeviceAllocation *Reference;
DeviceAllocation *A{nullptr};
DeviceAllocation *B{nullptr};
DeviceAllocation *C{nullptr};
DeviceAllocation *Computed{nullptr};
DeviceAllocation *Reference{nullptr};

/// Number of copies of the problem workspace which are visited sequentially during
/// profiling to avoid camping in the last level cache.
int problem_count;
int problem_count{1};

library::GemmUniversalConfiguration configuration;
library::GemmUniversalArguments arguments;
@ -153,13 +141,6 @@ public:

/// Buffer used for the cutlass reduction operations' host workspace
std::vector<uint8_t> reduction_host_workspace;

//
// Methods
//

GemmWorkspace():
A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr), problem_count(1) { }
};

protected:

@ -42,8 +42,8 @@
#include "cutlass/profiler/rank_2k_operation_profiler.h"
#include "cutlass/profiler/trmm_operation_profiler.h"
#include "cutlass/profiler/symm_operation_profiler.h"
#include "cutlass/profiler/conv2d_operation_profiler.h"
#include "cutlass/profiler/conv3d_operation_profiler.h"
#include "cutlass/profiler/conv2d_operation_profiler.h"
#include "cutlass/profiler/conv3d_operation_profiler.h"
#include "cutlass/profiler/sparse_gemm_operation_profiler.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -55,7 +55,7 @@ namespace profiler {

CutlassProfiler::CutlassProfiler(
Options const &options
):
):
options_(options) {

operation_profilers_.emplace_back(new GemmOperationProfiler(options));
@ -145,7 +145,6 @@ int CutlassProfiler::profile_() {

int result = 0;
DeviceContext device_context;

// For all profilers
for (auto & profiler : operation_profilers_) {

@ -193,8 +192,8 @@ void CutlassProfiler::print_usage_(std::ostream &out) {
<< " $ cutlass_profiler --operation=RankK --help\n\n"
<< " $ cutlass_profiler --operation=Trmm --help\n\n"
<< " $ cutlass_profiler --operation=Symm --help\n\n"
<< " $ cutlass_profiler --operation=Conv3d --help\n\n"
<< " $ cutlass_profiler --operation=Conv2d --help\n\n"
<< " $ cutlass_profiler --operation=Conv3d --help\n\n"
<< " $ cutlass_profiler --operation=Conv2d --help\n\n"
<< " $ cutlass_profiler --operation=SparseGemm --help\n\n"
;
}

@ -36,6 +36,7 @@
#include <stdexcept>
#include <iomanip>
#include <ios>
#include <vector>

#include "cutlass/core_io.h"

@ -167,7 +168,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
// default value
this->k = 1024;
}


if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
// default value
this->split_k_mode = library::SplitKMode::kSerial;
@ -421,6 +422,7 @@ void GemmOperationProfiler::initialize_result_(
bool GemmOperationProfiler::initialize_reduction_configuration_(
library::Operation const *operation,
ProblemSpace::Problem const &problem) {

library::GemmDescription const &gemm_desc =
static_cast<library::GemmDescription const&>(operation->description());

@ -577,7 +579,6 @@ Status GemmOperationProfiler::initialize_workspace(
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {

if (options.execution_mode != ExecutionMode::kDryRun) {

uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration);
gemm_workspace_.host_workspace.resize(workspace_size, 0);

@ -620,7 +621,6 @@ Status GemmOperationProfiler::initialize_workspace(
results_.back().verification_map[provider] = Disposition::kNotRun;
}
}

return status;
}

@ -794,7 +794,6 @@ bool GemmOperationProfiler::verify_with_cublas_(
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {


#if CUTLASS_ENABLE_CUBLAS

library::GemmDescription const &gemm_desc =

@ -51,6 +51,8 @@
#include "cutlass/profiler/operation_profiler.h"
#include "cutlass/profiler/gpu_timer.h"

#include "cutlass/trace.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@ -66,7 +68,7 @@ OperationProfiler::OperationProfiler(
library::OperationKind kind,
ArgumentDescriptionVector const &arguments,
ProviderVector const & verification_providers
):
):
kind_(kind), arguments_(arguments) {

ArgumentDescriptionVector tile_description_arguments{
@ -93,8 +95,8 @@ OperationProfiler::OperationProfiler(

for (auto provider : verification_providers) {
if (std::find(
options.verification.providers.begin(),
options.verification.providers.end(),
options.verification.providers.begin(),
options.verification.providers.end(),
provider) != options.verification.providers.end()) {

verification_providers_.push_back(provider);
@ -118,14 +120,14 @@ void OperationProfiler::print_usage(std::ostream &out) const {
size_t const kAliasStart = 10;

size_t columns = 0;


std::string type_str = to_string(desc.type);
columns += type_str.size();

out << " [" << type_str << "]";

if (columns < kAliasStart) {
out << std::string(kAliasStart - columns, ' ');
out << std::string(kAliasStart - columns, ' ');
}

columns = 0;
@ -161,7 +163,6 @@ bool OperationProfiler::satisfies(
return false;
}
}

int64_t int_value;

if (arg_as_int(int_value, "inst_m", problem_space, problem)) {
@ -252,14 +253,79 @@ bool OperationProfiler::satisfies(
return true;
}

///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)

std::ostream& operator<<(std::ostream& out, library::Provider provider) {
if (provider == library::Provider::kNone) {
out << "kNone";
}
else if (provider == library::Provider::kCUTLASS) {
out << "kCUTLASS";
}
else if (provider == library::Provider::kReferenceHost) {
out << "kReferenceHost";
}
else if (provider == library::Provider::kReferenceDevice) {
out << "kReferenceDevice";
}
else if (provider == library::Provider::kCUBLAS) {
out << "kCUBLAS";
}
else if (provider == library::Provider::kCUDNN) {
out << "kCUDNN";
}
else {
out << "kInvalid";
}

return out;
}

std::ostream& operator<<(std::ostream& out, library::OperationKind provider) {
if (provider == library::OperationKind::kGemm) {
out << "kGemm";
}
else if (provider == library::OperationKind::kRankK) {
out << "kRankK";
}
else if (provider == library::OperationKind::kRank2K) {
out << "kRank2K";
}
else if (provider == library::OperationKind::kTrmm) {
out << "kTrmm";
}
else if (provider == library::OperationKind::kSymm) {
out << "kSymm";
}
else if (provider == library::OperationKind::kConv2d) {
out << "kConv2d";
}
else if (provider == library::OperationKind::kConv3d) {
out << "kConv3d";
}
else if (provider == library::OperationKind::kEqGemm) {
out << "kEqGemm";
}
else if (provider == library::OperationKind::kSparseGemm) {
out << "kSparseGemm";
}
else if (provider == library::OperationKind::kReduction) {
out << "kReduction";
}
else {
out << "kInvalid";
}

return out;
}

#endif // defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)

/// Entry point to profile all operations in the manifest
int OperationProfiler::profile_all(
Options const &options,
library::Manifest const &manifest,
Options const &options,
library::Manifest const &manifest,
DeviceContext &device_context) {

ProblemSpace problem_space(arguments_, options.cmdline);

// 1. Construct performance report
@ -282,13 +348,45 @@ int OperationProfiler::profile_all(
for (auto const& operation_ptr : manifest) {

library::Operation const *operation = operation_ptr.get();
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
std::cerr << " Operation: " << typeid(*operation).name() << "\n"
<< " name: " << operation->description().name << "\n"
<< " kind: " << operation->description().kind << "\n"
<< " provider: " << operation->description().provider << "\n";
#endif // CUTLASS_DEBUG_TRACE_LEVEL

auto min_cc = operation->description().tile_description.minimum_compute_capability;
auto max_cc = operation->description().tile_description.maximum_compute_capability;

#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
std::cerr << " min_cc: " << min_cc << "\n";
std::cerr << " max_cc: " << max_cc << "\n";
#endif

// Clear named allocations
device_context.free();

#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
if (operation->description().kind != kind_) {
std::cerr << " @ kind " << operation->description().kind
<< " != kind_ " << kind_ << "\n";
}
if (operation->description().provider != library::Provider::kCUTLASS) {
std::cerr << " @ provider " << operation->description().provider
<< " != library::Provider::kCUTLASS\n";
}
if (options.device.compute_capability() < min_cc) {
std::cerr << " @ compute_capability "
<< options.device.compute_capability()
<< " < min_cc " << min_cc << "\n";
}
if (options.device.compute_capability() > max_cc) {
std::cerr << " @ compute_capability "
<< options.device.compute_capability()
<< " > max_cc " << max_cc << "\n";
}
#endif

// Execute compatible cutlass operations if they satisfy the current device's compute capability
if (operation->description().kind == kind_ &&
operation->description().provider == library::Provider::kCUTLASS &&
@ -296,17 +394,16 @@ int OperationProfiler::profile_all(
options.device.compute_capability() <= max_cc) {

std::string operation_name(operation->description().name);

// Filter kernels by name
bool filtered_by_name = options.operation_names.empty();
if (!filtered_by_name) {


for (auto const & op_name : options.operation_names) {
if (find_string_matches_(op_name, operation_name)) {
filtered_by_name = true;
break;
}
}
}
}

for (auto const & op_name : options.excluded_operation_names) {
@ -333,10 +430,10 @@ int OperationProfiler::profile_all(
problem);

if (status == Status::kErrorInternal) {


// If there was an internal error, consume the CUDA error and move to the next operation.
(void)cudaGetLastError();


report.append_results(results_);
continue;
}
@ -385,9 +482,9 @@ int OperationProfiler::profile_all(

continue_profiling = this->verify_cutlass(
options,
report,
device_context,
operation,
report,
device_context,
operation,
problem_space,
problem);

@ -419,10 +516,10 @@ int OperationProfiler::profile_all(
if (continue_profiling && options.profiling.enabled) {

continue_profiling = this->profile(
options,
report,
device_context,
operation,
options,
report,
device_context,
operation,
problem_space,
problem);
}
@ -459,7 +556,7 @@ void OperationProfiler::sleep(int sleep_duration) {
SleepEx(sleep_duration, false);
#else
// sleep not supported
#endif
#endif
}
}

@ -485,7 +582,7 @@ Disposition OperationProfiler::compare_tensors(

// bit-level equality
passed = DeviceAllocation::block_compare_equal(
experimental.type(),
experimental.type(),
experimental.data(),
reference.data(),
count);
@ -494,7 +591,7 @@ Disposition OperationProfiler::compare_tensors(

// relative error function
passed = DeviceAllocation::block_compare_relatively_equal(
experimental.type(),
experimental.type(),
experimental.data(),
reference.data(),
count,
@ -516,7 +613,7 @@ void OperationProfiler::save_workspace(
for (auto const & named_allocation : device_context) {

DeviceAllocation *allocation = named_allocation.second;


std::stringstream filename;

filename << desc.name << "_" << library::to_string(provider) << "_";
@ -535,7 +632,7 @@ void OperationProfiler::save_workspace(
if (options.report.verbose) {
std::cout << "wrote '" << filename.str() << "'" << std::endl;
}
}
}
}


@ -575,7 +672,7 @@ Status OperationProfiler::profile_cutlass_(
return status;
}
}


//
// Initialize GPU timer
//
@ -590,7 +687,7 @@ Status OperationProfiler::profile_cutlass_(

int iteration = 0;
for (; iteration < Iterations; ++iteration) {


status = operation->run(
arguments,
host_workspace,
@ -610,7 +707,7 @@ Status OperationProfiler::profile_cutlass_(
//
// Update performance result
//


runtime = timer.duration(iteration);

return status;
@ -618,7 +715,7 @@ Status OperationProfiler::profile_cutlass_(

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Sets operation description
/// Sets operation description
void OperationProfiler::initialize_result_(
PerformanceResult &result,
library::OperationDescription const &operation_desc,
@ -657,7 +754,7 @@ void OperationProfiler::set_argument(
result.arguments.at(problem_space.argument_index(name)) = make_pair(std::string(name), value);
}

void OperationProfiler::set_argument(
void OperationProfiler::set_argument(
PerformanceResult &result,
char const *name,
ProblemSpace const &problem_space,
@ -669,12 +766,12 @@ void OperationProfiler::set_argument(

/// finds string matches filter_string in operation_name
bool OperationProfiler::find_string_matches_(
std::string const &filter_string,
std::string const &filter_string,
std::string const &operation_name) {
// Returns true if all substrings appear in the operation_name in order


// Split filter_string of the format "gemm*f32*nt" to tokens ["gemm", "f32", "nt"]
std::string item;
std::string item;
std::istringstream iss(filter_string);
std::vector<std::string> filter_tokens;
while (std::getline(iss, item, '*')) {
@ -692,7 +789,7 @@ bool OperationProfiler::find_string_matches_(
return false;
}
}
start += (idx + token.length());
start += (idx + token.length());
}

// All tokens in filter_string found in operation_name

@ -234,7 +234,7 @@ Status SymmOperationProfiler::SymmProblem::parse(

/// Total number of bytes loaded
int64_t SymmOperationProfiler::SymmProblem::bytes(library::SymmDescription const &operation_desc) const {
int64_t bytes;
int64_t bytes = 0;
// Input bytes read and Output bytes written for the gemm problem
// Half matrix including the diagonal will have (X*(X+1))/2 elements
if (operation_desc.side_mode == SideMode::kLeft) {

@ -121,7 +121,7 @@ struct CommandLine {
* Returns the commandline parameter for a given index (not including flags)
*/
template <typename value_t>
void get_cmd_line_argument(int index, value_t& val) const {
void get_cmd_line_argument(size_t index, value_t& val) const {
using namespace std;
if (index < args.size()) {
istringstream str_stream(args[index]);

@ -63,7 +63,7 @@ using ComplexDouble = cuda::std::complex<double>;
// User could potentially define Half instead of cute::
#ifndef BLAM_HALF_TYPE
#define BLAM_HALF_TYPE 1
#include <cute/numeric/half.hpp>
#include <cute/numeric/numeric_types.hpp>
namespace blam {
using Half = cute::half_t;
}

@ -69,7 +69,7 @@ CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0,
return;
}

int total_elements = frag.size();
int total_elements = int(frag.size());

if (M < 0 || M > total_elements) {
if (thread_id == 0 && block_id == 0)

@ -42,8 +42,8 @@
namespace cutlass {

__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
const float4 *weight,
const int m, const int n, float epsilon) {
const float4 *weight,
const int m, const int n, float epsilon) {
const int m_idx = blockIdx.x;
const int tid = threadIdx.x;
const int bdimx = blockDim.x;
@ -115,9 +115,9 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,

template<typename T>
__global__ void rmsnorm_twoPassAlgo_e1(T* output,
const T* input,
const T* weight,
const int m, const int n,
const T* input,
const T* weight,
const int m, const int n,
float epsilon)
{
const int m_idx = blockIdx.x;
@ -156,7 +156,7 @@ void rmsnorm(cutlass::MatrixCoord tensor_size,
TensorRef<T, layout::RowMajor> ref_output,
TensorRef<T, layout::RowMajor> ref_input,
TensorRef<T, layout::RowMajor> ref_weight,
cudaStream_t stream, float epsilon = 1e-5){
cudaStream_t stream, float epsilon = 1e-5f){
const int m = tensor_size.row();
const int n = tensor_size.column();
T* output = ref_output.data();

@ -112,9 +112,13 @@ public:
/// Example
/// int2: kBitsStoredVec = 8; kElementsPerStoredVec = 4; kNumStoragePerStoredVec = 1 uint8_t;
/// int4: kBitsStoredVec = 8; kElementsPerStoredVec = 2; kNumStoragePerStoredVec = 1 uint8_t;
static int const kBitsStoredVec = (sizeof_bits<Element>::value < 8) ? cutlass::lcm(static_cast<int>(sizeof_bits<Element>::value), 8) : sizeof_bits<Element>::value;
static int const kElementsPerStoredVec = kBitsStoredVec / sizeof_bits<Element>::value;
static int const kNumStoragePerStoredVec = kBitsStoredVec / (sizeof(Element) * 8);
static constexpr int kBitsStoredVec = (sizeof_bits<Element>::value < 8) ? cutlass::lcm(sizeof_bits<Element>::value, 8) : sizeof_bits<Element>::value;
static constexpr int kElementsPerStoredVec = kBitsStoredVec / sizeof_bits<Element>::value;
static constexpr int kNumStoragePerStoredVec = kBitsStoredVec / (sizeof(Element) * 8);

static_assert(kBitsStoredVec != 0, "kBitsStoredVec can not be zero");
static_assert(kElementsPerStoredVec != 0, "kElementsPerStoredVec can not be zero");
static_assert(kNumStoragePerStoredVec != 0, "kNumStoragePerStoredVec can not be zero");
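// Worked example (restating the comment above): for Element = int4b_t,
// sizeof_bits<Element>::value = 4 < 8, so kBitsStoredVec = lcm(4, 8) = 8,
// kElementsPerStoredVec = 8 / 4 = 2, and since sizeof(Element) is one byte,
// kNumStoragePerStoredVec = 8 / 8 = 1, i.e. one uint8_t of storage.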

private:


@ -108,4 +108,354 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape

/////////////////////////////////////////////////////////////////////////////////////////////////

// Strides for convolutions

// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)
// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order
// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout
// right in KTRSC order and can be coalesced to just k.
// We enforce this condition here with asserts.
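// Worked example (illustrative): a packed rank-5 NZPQK fprop output of shape
// (N, Z, P, Q, K) has strides (Z*P*Q*K, P*Q*K, Q*K, K, 1). The asserts below
// verify exactly this layout-right relationship, so the rank-3 stride can
// keep only stride_output[RankT-2] (= K, the q step) for fprop/dgrad, or
// stride_output[0] for wgrad's layout-right KTRSC strides.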
template <class IntT, size_t RankT_>
cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
cute::array<int32_t, RankT_> shape_output,
cute::array<IntT, RankT_> stride_output,
cutlass::conv::Operator conv_op) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
static_assert(RankT_ >= 3u);
constexpr static int RankT = static_cast<int>(RankT_);

assert(stride_output[RankT-1] == 1);
cute::for_each(cute::make_seq<RankT-2>{}, [&](auto i) {
assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]);
});

auto s_copy = s;
cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ?
stride_output[0] :
stride_output[RankT-2];
return s_copy;
}

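// Usage sketch for the coalescing overload above (hypothetical extents):
//
//   cute::array<int32_t, 5> shape_out  = {2, 4, 8, 8, 32};           // (N,Z,P,Q,K)
//   cute::array<int64_t, 5> stride_out = {4*8*8*32, 8*8*32, 8*32, 32, 1};
//   auto s = cutlass::make_cute_packed_stride(
//       cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>{},
//       shape_out, stride_out, cutlass::conv::Operator::kFprop);
//   // cute::get<0>(s) == 32: the coalesced step is the Q stride, which
//   // equals K for a packed NZPQK tensor.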
//
// Activation tensor ((w, h, d, n), _1) for fprop kernel
//

// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
template <class IntT>
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
cute::array<IntT, 3> stride_nwc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
assert(stride_nwc[2] == 1);
auto s_copy = s;
cute::get<0,0>(s_copy) = stride_nwc[1];
cute::get<0,1>(s_copy) = stride_nwc[0];
return s_copy;
}

// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
template <class IntT>
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
cute::array<IntT, 4> stride_nhwc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
assert(stride_nhwc[3] == 1);
auto s_copy = s;
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
cute::get<0,i>(s_copy) = stride_nhwc[2-i];
});
return s_copy;
}

// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
template <class IntT>
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
cute::array<IntT, 5> stride_ndhwc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_ndhwc[4] == 1);
auto s_copy = s;
cute::for_each(cute::make_seq<4>{}, [&](auto i) {
cute::get<0,i>(s_copy) = stride_ndhwc[3-i];
});
return s_copy;
}

//
// Filter tensor (k, (_1, s, r, t)) for fprop kernel
//

// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
template <class IntT>
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
cute::array<IntT, 3> stride_ksc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_ksc[2] == 1);
auto s_copy = s;
cute::get<0,0>(s_copy) = stride_ksc[0];
cute::get<1,1>(s_copy) = stride_ksc[1];
return s_copy;
}

// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
template <class IntT>
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
cute::array<IntT, 4> stride_krsc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_krsc[3] == 1);
auto s_copy = s;
cute::get<0,0>(s_copy) = stride_krsc[0];
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
});
return s_copy;
}

// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
template <class IntT>
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
cute::array<IntT, 5> stride_ktrsc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_ktrsc[4] == 1);
auto s_copy = s;
cute::get<0,0>(s_copy) = stride_ktrsc[0];
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
});
return s_copy;
}

//
// Activation tensor (_1, (w, h, d, n)) for wgrad kernel
//
// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel
//

// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
template <class IntT>
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
cute::array<IntT, 3> stride_nwc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_nwc[2] == 1);
auto s_copy = s;
if (ConvOp == cutlass::conv::Operator::kWgrad) {
cute::get<1,0>(s_copy) = stride_nwc[1];
cute::get<1,1>(s_copy) = stride_nwc[0];
}
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
// stride_nwc in dgrad is ksc.
cute::get<1,0>(s_copy) = stride_nwc[0];
cute::get<1,1>(s_copy) = stride_nwc[1];
}
return s_copy;
}

// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
template <class IntT>
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
cute::array<IntT, 4> stride_nhwc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_nhwc[3] == 1);
auto s_copy = s;
if (ConvOp == cutlass::conv::Operator::kWgrad) {
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
cute::get<1,i>(s_copy) = stride_nhwc[2-i];
});
}
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
// stride_nhwc in dgrad is krsc.
cute::get<1,0>(s_copy) = stride_nhwc[0];
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
cute::get<1,2-i>(s_copy) = stride_nhwc[i+1];
});
}
return s_copy;
}

// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
template <class IntT>
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
cute::array<IntT, 5> stride_ndhwc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_ndhwc[4] == 1);
auto s_copy = s;
if (ConvOp == cutlass::conv::Operator::kWgrad) {
cute::for_each(cute::make_seq<4>{}, [&](auto i) {
cute::get<1,i>(s_copy) = stride_ndhwc[3-i];
});
}
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
// stride_ndhwc in dgrad is ktrsc.
cute::get<1,0>(s_copy) = stride_ndhwc[0];
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1];
});
}
return s_copy;
}

//
// NZPQ tensor (_1, nzpq) for wgrad kernel
//

// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
template <class IntT>
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
cute::array<IntT, 3> stride_nqk,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_nqk[2] == 1);
auto s_copy = s;
cute::get<1>(s_copy) = stride_nqk[1];
return s_copy;
}

// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
template <class IntT>
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
cute::array<IntT, 4> stride_npqk,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_npqk[3] == 1);
auto s_copy = s;
cute::get<1>(s_copy) = stride_npqk[2];
return s_copy;
}

// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
template <class IntT>
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
cute::array<IntT, 5> stride_nzpqk,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_nzpqk[4] == 1);
auto s_copy = s;
cute::get<1>(s_copy) = stride_nzpqk[3];
return s_copy;
}


//
// Wgrad output tensor (k, (_1, s, r, t), _0)
//

// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
template <class IntT>
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
[[maybe_unused]] cute::array<int32_t, 3> shape_output,
cute::array<IntT, 3> stride_ksc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_ksc[2] == 1);
auto s_copy = s;
cute::get<0,0>(s_copy) = stride_ksc[0];
cute::get<1,1>(s_copy) = stride_ksc[1];
return s_copy;
}

// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0)
template <class IntT>
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>> s,
[[maybe_unused]] cute::array<int32_t, 4> shape_output,
cute::array<IntT, 4> stride_krsc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_krsc[3] == 1);
auto s_copy = s;
cute::get<0,0>(s_copy) = stride_krsc[0];
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
});
return s_copy;
}

// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
template <class IntT>
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,
[[maybe_unused]] cute::array<int32_t, 5> shape_output,
cute::array<IntT, 5> stride_ktrsc,
conv::Operator ConvOp) {
static_assert(std::is_integral_v<IntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");

assert(stride_ktrsc[4] == 1);
auto s_copy = s;
cute::get<0,0>(s_copy) = stride_ktrsc[0];
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
});
return s_copy;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

@ -40,8 +40,9 @@
#include <cute/util/type_traits.hpp>
#include <cute/tensor.hpp>

#include <cute/numeric/half.hpp>
#include <cute/numeric/numeric_types.hpp>
#include <cute/numeric/complex.hpp>

#include <cutlass/layout/layout.h>

// The computed infinity norm does not include
@ -233,7 +234,8 @@
print_relative_error(
T1 const& data,
T2 const& reference,
bool print_verbose = false,
bool print_error = true) {
bool print_error = true,
double error_margin = 0.00001) {
using std::abs; using std::sqrt;

// Use either double or complex<double> for error computation
@ -252,8 +254,8 @@ print_relative_error(
double tot_norm_sq = 0;
double tot_ind_rel_err = 0;
double max_ind_rel_err = 0;
for (std::size_t i = 0; i < n; ++i)
{
double max_diff = 0;
for (std::size_t i = 0; i < n; ++i) {
error_type val = data[i];
error_type ref = reference[i];

@ -267,6 +269,9 @@ print_relative_error(
// Maximum relative error
max_ind_rel_err = std::max(max_ind_rel_err, rel_error);

// Maximum delta in value error
max_diff = std::max(max_diff, diff);

// Total relative error
tot_error_sq += diff * diff;
tot_norm_sq += aref * aref;
@ -276,18 +281,40 @@
}
}

printf("Vector reference norm: [%.5e]\n", sqrt(tot_norm_sq));
double ave_rel_err = tot_ind_rel_err / double(n);
if (print_error) {
printf("Average relative error: %.3e\n", ave_rel_err);
}

if (print_error) {
printf("Maximum relative error: %.3e\n", max_ind_rel_err);
}

if (print_error) {
printf("Maximum difference : %.3e\n", max_diff);
}

double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps));
if (print_error)
printf("Vector relative error: [%.5e]\n", tot_rel_err);
if (print_error) {
printf("Vector relative error: %.3e\n", tot_rel_err);
}

double ave_rel_err = tot_ind_rel_err / double(n);
if (print_error)
printf("Average relative error: [%.5e]\n", ave_rel_err);
printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq));

if (print_error)
printf("Maximum relative error: [%.5e]\n", max_ind_rel_err);

return (tot_rel_err == 0.0) ? EXIT_SUCCESS : EXIT_FAILURE;
return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE;
}

// Overload for cute::Tensor<>
template <class Engine, class Layout>
int
print_relative_error(
cute::Tensor<Engine, Layout> data,
cute::Tensor<Engine, Layout> reference,
bool print_verbose = false,
bool print_error = true,
double error_margin = 0.00001) {
assert(size(data) == size(reference));
return print_relative_error(static_cast<std::size_t>(size(data)),
data, reference,
print_verbose, print_error, error_margin);
}

@ -68,7 +68,7 @@ template <int Rank>
struct LinearToCoordinateHelper<Rank, 0> {

CUTLASS_HOST_DEVICE
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &) const {
coord[Rank - 1] = int(idx);
}
};

@ -134,9 +134,8 @@ struct RandomGaussianFunc {
stddev(static_cast<FloatType>(stddev_)),
int_scale(int_scale_) {

float_scale_up = FloatType(IntType(1) << int_scale);
float_scale_up += FloatType(0.5) * float_scale_up;
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits
float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale);
}
};

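// Illustration of the updated scaling (values assumed): with int_scale = 2,
// float_scale_up = (2 << 2) = 8 and float_scale_down = 1.0f / 8. A draw of
// 0.34f becomes llround(0.34f * 8) / 8.0f = 3 / 8.0f = 0.375f, so every value
// is snapped to a multiple of 2^-(int_scale + 1), clamping the low-order bits.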
@ -172,8 +171,8 @@ struct RandomGaussianFunc {

Element result;
if (params.int_scale >= 0) {
rnd = FloatType(IntType(rnd * params.float_scale_up));
result = Element(rnd * params.float_scale_down);
rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up)));
result = Element(IntType(rnd * params.float_scale_down));
}
else {
result = Element(rnd);
@ -448,9 +447,8 @@ struct RandomUniformFunc {
max(static_cast<FloatType>(max_)),
int_scale(int_scale_) {

float_scale_up = FloatType(IntType(1) << int_scale);
float_scale_up += FloatType(0.5) * float_scale_up;
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits
float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale);
}
};

@ -489,8 +487,8 @@
Element result;

if (params.int_scale >= 0) {
rnd = FloatType(IntType(rnd * params.float_scale_up));
result = Element(rnd * params.float_scale_down);
rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up)));
result = Element(IntType(rnd * params.float_scale_down));
}
else {
result = Element(rnd);
@ -774,9 +772,13 @@ struct RandomSparseMetaFunc {
MetaSizeInBits(MetaSizeInBits_) {
if (MetaSizeInBits_ == 2) {
range = 6;
} else if (MetaSizeInBits_ == 4) {
}
else if (MetaSizeInBits_ == 4) {
range = 2;
}
else {
throw std::invalid_argument("Invalid MetaSizeInBits");
}
}
};

@ -1161,34 +1163,10 @@ struct TensorClearPartialFunc {

/// Parameters structure
struct Params {

//
// Data members
//

TensorView view;
Element element;
FillMode fill_mode;
int alignment;

/// Default ctor
CUTLASS_HOST_DEVICE
Params(): fill_mode(FillMode::kNone) { }

//
// Methods
//

/// Construction of Gaussian RNG functor.
Params(
TensorView view_,
Element element_,
FillMode fill_mode_,
int alignment_
):
view(view_), element(element_), fill_mode(fill_mode_), alignment(alignment_) {

}
TensorView view{};
Element element{};
FillMode fill_mode{FillMode::kNone};
int alignment{0};
};

//
@ -1307,7 +1285,7 @@ void TensorClearPartial(

TensorForEach<Func, Layout::kRank, Params>(
view.extent(),
Params(view, element, fill_mode, alignment),
Params{view, element, fill_mode, alignment},
/*grid_size*/0, /*block_size*/0,
stream
);

@ -120,7 +120,7 @@ __global__ void TensorTransformReducePartial(
ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]

int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
int64_t size = view_A.size();
auto size = static_cast<int64_t>(view_A.size());

__shared__ ComputeType scratchpad[kBlockSize];

649 tools/util/include/cutlass/util/reference/host/conv.hpp Normal file
@ -0,0 +1,649 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Reference implementation for CONV in host-side code.
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/complex.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"

#include "cute/tensor.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::reference::host {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

template<class EngineAct, class LayoutAct>
bool
is_activation_in_bounds(
    cute::Tensor<EngineAct, LayoutAct> const& activation,
    int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_) {
  return ((n_ >= 0 && n_ < size<4>(activation)) &&
          (d_ >= 0 && d_ < size<3>(activation)) &&
          (h_ >= 0 && h_ < size<2>(activation)) &&
          (w_ >= 0 && w_ < size<1>(activation)) &&
          (c_ >= 0 && c_ < size<0>(activation)));
}

template<class EngineAct, class LayoutAct>
bool
is_activation_in_bounds(
    cute::Tensor<EngineAct, LayoutAct> const& activation,
    int32_t n_, int32_t h_, int32_t w_, int32_t c_) {
  return ((n_ >= 0 && n_ < size<3>(activation)) &&
          (h_ >= 0 && h_ < size<2>(activation)) &&
          (w_ >= 0 && w_ < size<1>(activation)) &&
          (c_ >= 0 && c_ < size<0>(activation)));
}

template<class EngineAct, class LayoutAct>
bool
is_activation_in_bounds(
    cute::Tensor<EngineAct, LayoutAct> const& activation,
    int32_t n_, int32_t w_, int32_t c_) {
  return ((n_ >= 0 && n_ < size<2>(activation)) &&
          (w_ >= 0 && w_ < size<1>(activation)) &&
          (c_ >= 0 && c_ < size<0>(activation)));
}

} // namespace detail
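
// Illustrative sketch (not part of the original header; buffer sizes are
// hypothetical): the predicates above are what give the reference kernels
// implicit zero padding. A halo coordinate such as w = -1 simply fails the
// bounds check, so the tap is skipped -- equivalent to reading a zero.
inline bool example_halo_read_is_skipped() {
  float buf[3 * 5 * 2] = {};  // C=3, W=5, N=2, compact column-major layout
  auto act = cute::make_tensor(&buf[0], cute::make_shape(3, 5, 2));
  bool inside = detail::is_activation_in_bounds(act, /*n_=*/0, /*w_=*/4, /*c_=*/2);   // true
  bool halo   = detail::is_activation_in_bounds(act, /*n_=*/0, /*w_=*/-1, /*c_=*/0);  // false
  return inside && !halo;
}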

template<
  class ElementAcc_,
  class ElementScalar_,
  class ElementCompute_,
  class ElementC_,
  class ElementOut_,
  class TensorAlpha_,
  class TensorBeta_,
  class TensorBias_,
  class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>>
struct ConvEpilogueFusionParams {
  using ElementAcc = ElementAcc_;
  using ElementScalar = ElementScalar_;
  using ElementCompute = ElementCompute_;
  using ElementC = ElementC_;
  using ElementOut = ElementOut_;
  using TensorAlpha = TensorAlpha_;
  using TensorBeta = TensorBeta_;
  using TensorBias = TensorBias_;
  using ActivationFunctor = ActivationFunctor_;

  ElementScalar alpha = ElementScalar(1);
  ElementScalar beta = ElementScalar(0);

  TensorAlpha tensor_alpha{};
  TensorBeta tensor_beta{};
  TensorBias tensor_bias{};
};
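
// Illustrative sketch (not part of the original header): every reference
// kernel below applies the same fused epilogue per output element,
//   D = ActivationFunctor(alpha * acc + beta * C [+ bias]),
// where alpha/beta are read per output channel when tensor_alpha / tensor_beta
// are bound, and fall back to the scalar members otherwise. A plain-float
// analogue, with ReLU standing in for the activation functor:
inline float example_fused_epilogue(float acc, float c, float alpha, float beta, float bias) {
  float out = alpha * acc + beta * c + bias;  // scale accumulator, blend residual, add bias
  return out > 0.0f ? out : 0.0f;             // hypothetical ReLU activation
}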

template<
  cutlass::conv::Operator ConvOp,
  int NumSpatialDims,
  class TensorA,
  class TensorB,
  class TensorC,
  class TensorD,
  class ShapePadding,
  class StrideTraversal,
  class ShapeDilation,
  class EpilogueFusionParams>
struct ConvReferenceImpl {
  using ElementAcc = typename EpilogueFusionParams::ElementAcc;
  using ElementC = typename EpilogueFusionParams::ElementC;
  using ElementOut = typename EpilogueFusionParams::ElementOut;
  using ElementScalar = typename EpilogueFusionParams::ElementScalar;
  using ElementCompute = typename EpilogueFusionParams::ElementCompute;
  using ElementBias = typename EpilogueFusionParams::TensorBias::value_type;
  using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor;

  // Input related converters
  NumericConverter<ElementCompute, ElementAcc> acc_converter;
  NumericConverter<ElementCompute, ElementC> residual_converter;
  NumericConverter<ElementCompute, ElementBias> bias_converter;
  // Scale related converter
  NumericConverter<ElementCompute, ElementScalar> scale_converter;
  // Output related converter
  NumericConverter<ElementOut, ElementCompute> output_converter;

  EpilogueFusionParams& epi_fusion_params_;

  TensorA const& tensor_a_;
  TensorB const& tensor_b_;
  TensorC const& tensor_c_;
  TensorD& tensor_d_;

  ShapePadding const& padding_;
  StrideTraversal const& tstride_;
  ShapeDilation const& dilation_;

  // Epilogue activation operation
  ActivationFunctor epi_activation;

  ConvReferenceImpl(
    TensorA const& tensor_a,
    TensorB const& tensor_b,
    TensorC const& tensor_c,
    TensorD& tensor_d,
    ShapePadding const& padding,
    StrideTraversal const& tstride,
    ShapeDilation const& dilation,
    EpilogueFusionParams& epi_fusion_params)
  : tensor_a_(tensor_a),
    tensor_b_(tensor_b),
    tensor_c_(tensor_c),
    tensor_d_(tensor_d),
    padding_(padding),
    tstride_(tstride),
    dilation_(dilation),
    epi_fusion_params_(epi_fusion_params) {
    static_assert(rank(ShapePadding{}) == rank(ShapeDilation{}));
    static_assert(rank(ShapePadding{}) == rank(StrideTraversal{}));
  }

  void compute_reference() {
    if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
      fprop_reference(cute::Int<NumSpatialDims>{});
    }
    else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
      dgrad_reference(cute::Int<NumSpatialDims>{});
    }
    else {
      wgrad_reference(cute::Int<NumSpatialDims>{});
    }
  }

private:
  // Specialization for 1D fprop kernel
  void fprop_reference(cute::Int<1> spatial_dims) {
    int32_t N = size<2>(tensor_d_);
    int32_t Q = size<1>(tensor_d_);
    int32_t K = size<0>(tensor_d_);
    int32_t S = size<1>(tensor_b_);
    int32_t C = size<0>(tensor_b_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(2)
#endif
    for (int32_t n = 0; n < N; ++n) {
      for (int32_t q = 0; q < Q; ++q) {
        for (int32_t k = 0; k < K; ++k) {
          auto accumulator = ElementAcc(0);
          for (int32_t s = 0; s < S; ++s) {
            for (int32_t c = 0; c < C; ++c) {
              int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
              if (detail::is_activation_in_bounds(tensor_a_, n, w, c)) {
                accumulator += ElementAcc(tensor_a_(c, w, n) * tensor_b_(c, s, k));
              }
            }
          }
          ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
            epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
          ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
            epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
          ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
            scale_converter(beta) * residual_converter(tensor_c_(k, q, n));
          if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
            output += bias_converter(epi_fusion_params_.tensor_bias[k]);
          }
          output = epi_activation(output);
          tensor_d_(k, q, n) = output_converter(output);
        }
      }
    }
  }
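
  // Worked mapping note (illustrative, not part of the original header): the
  // input coordinate above follows the standard cross-correlation rule
  //   w = q * stride - pad + s * dilation,
  // which pairs with the usual output-extent relation
  //   Q = floor((W + 2*pad - dilation*(S - 1) - 1) / stride) + 1.
  // Example: W=8, S=3, pad=1, stride=2, dilation=1 gives Q = (8+2-2-1)/2 + 1 = 4.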

  // Specialization for 2D fprop kernel
  void fprop_reference(cute::Int<2> spatial_dims) {
    int32_t N = size<3>(tensor_d_);
    int32_t P = size<2>(tensor_d_);
    int32_t Q = size<1>(tensor_d_);
    int32_t K = size<0>(tensor_d_);
    int32_t R = size<2>(tensor_b_);
    int32_t S = size<1>(tensor_b_);
    int32_t C = size<0>(tensor_b_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t n = 0; n < N; ++n) {
      for (int32_t p = 0; p < P; ++p) {
        for (int32_t q = 0; q < Q; ++q) {
          for (int32_t k = 0; k < K; ++k) {
            auto accumulator = ElementAcc(0);
            for (int32_t r = 0; r < R; ++r) {
              for (int32_t s = 0; s < S; ++s) {
                for (int32_t c = 0; c < C; ++c) {
                  int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
                  int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
                  if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c)) {
                    accumulator += ElementAcc(tensor_a_(c, w, h, n) * tensor_b_(c, s, r, k));
                  }
                }
              }
            }
            ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
              epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
            ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
              epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
            ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
              scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n));
            if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
              output += bias_converter(epi_fusion_params_.tensor_bias[k]);
            }
            output = epi_activation(output);
            tensor_d_(k, q, p, n) = output_converter(output);
          }
        }
      }
    }
  }

  // Specialization for 3D fprop kernel
  void fprop_reference(cute::Int<3> spatial_dims) {
    int32_t N = size<4>(tensor_d_);
    int32_t Z = size<3>(tensor_d_);
    int32_t P = size<2>(tensor_d_);
    int32_t Q = size<1>(tensor_d_);
    int32_t K = size<0>(tensor_d_);
    int32_t T = size<3>(tensor_b_);
    int32_t R = size<2>(tensor_b_);
    int32_t S = size<1>(tensor_b_);
    int32_t C = size<0>(tensor_b_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t n = 0; n < N; ++n) {
      for (int32_t z = 0; z < Z; ++z) {
        for (int32_t p = 0; p < P; ++p) {
          for (int32_t q = 0; q < Q; ++q) {
            for (int32_t k = 0; k < K; ++k) {
              auto accumulator = ElementAcc(0);
              for (int32_t t = 0; t < T; ++t) {
                for (int32_t r = 0; r < R; ++r) {
                  for (int32_t s = 0; s < S; ++s) {
                    for (int32_t c = 0; c < C; ++c) {
                      int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
                      int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
                      int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
                      if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c)) {
                        accumulator += ElementAcc(tensor_a_(c, w, h, d, n) * tensor_b_(c, s, r, t, k));
                      }
                    }
                  }
                }
              }
              ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
                epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
              ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
                epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
              ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
                scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n));
              if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
                output += bias_converter(epi_fusion_params_.tensor_bias[k]);
              }
              output = epi_activation(output);
              tensor_d_(k, q, p, z, n) = output_converter(output);
            }
          }
        }
      }
    }
  }

  // Specialization for 1D dgrad kernel
  void dgrad_reference(cute::Int<1> spatial_dims) {
    int32_t N = size<2>(tensor_d_);
    int32_t W = size<1>(tensor_d_);
    int32_t C = size<0>(tensor_d_);
    int32_t K = size<2>(tensor_b_);
    int32_t S = size<1>(tensor_b_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(2)
#endif
    for (int32_t n = 0; n < N; ++n) {
      for (int32_t w = 0; w < W; ++w) {
        for (int32_t c = 0; c < C; ++c) {
          auto accumulator = ElementAcc(0);
          for (int32_t k = 0; k < K; ++k) {
            for (int32_t s = 0; s < S; ++s) {
              int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);

              if (q % cute::get<0>(tstride_) == 0) {
                q /= cute::get<0>(tstride_);
              } else {
                continue;
              }

              if (detail::is_activation_in_bounds(tensor_a_, n, q, k)) {
                accumulator += ElementAcc(tensor_a_(k, q, n) * tensor_b_(c, s, k));
              }
            }
          }
          ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
            ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
          ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
            ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
          ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
            scale_converter(beta) * residual_converter(tensor_c_(c, w, n));
          if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
            output += bias_converter(epi_fusion_params_.tensor_bias[c]);
          }
          output = epi_activation(output);
          tensor_d_(c, w, n) = output_converter(output);
        }
      }
    }
  }
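
  // Inversion note (illustrative, not part of the original header): fprop
  // reads input w = q * stride - pad + s * dilation, so dgrad recovers
  //   q = (w + pad - s * dilation) / stride
  // and a tap contributes only when the division is exact. With pad=1,
  // stride=2: w=3, s=0 gives (3+1-0) = 4, divisible by 2, so q=2 contributes;
  // w=4 gives 5, not divisible, so that tap falls between forward-pass output
  // positions and is skipped.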

  // Specialization for 2D dgrad kernel
  void dgrad_reference(cute::Int<2> spatial_dims) {
    int32_t N = size<3>(tensor_d_);
    int32_t H = size<2>(tensor_d_);
    int32_t W = size<1>(tensor_d_);
    int32_t C = size<0>(tensor_d_);
    int32_t K = size<3>(tensor_b_);
    int32_t R = size<2>(tensor_b_);
    int32_t S = size<1>(tensor_b_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t n = 0; n < N; ++n) {
      for (int32_t h = 0; h < H; ++h) {
        for (int32_t w = 0; w < W; ++w) {
          for (int32_t c = 0; c < C; ++c) {
            auto accumulator = ElementAcc(0);
            for (int32_t k = 0; k < K; ++k) {
              for (int32_t r = 0; r < R; ++r) {
                for (int32_t s = 0; s < S; ++s) {
                  int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
                  int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);

                  if (q % cute::get<0>(tstride_) == 0) {
                    q /= cute::get<0>(tstride_);
                  } else {
                    continue;
                  }

                  if (p % cute::get<1>(tstride_) == 0) {
                    p /= cute::get<1>(tstride_);
                  } else {
                    continue;
                  }

                  if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k)) {
                    accumulator += ElementAcc(tensor_a_(k, q, p, n) * tensor_b_(c, s, r, k));
                  }
                }
              }
            }
            ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
              ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
            ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
              ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
            ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
              scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n));
            if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
              output += bias_converter(epi_fusion_params_.tensor_bias[c]);
            }
            output = epi_activation(output);
            tensor_d_(c, w, h, n) = output_converter(output);
          }
        }
      }
    }
  }

  // Specialization for 3D dgrad kernel
  void dgrad_reference(cute::Int<3> spatial_dims) {
    int32_t N = size<4>(tensor_d_);
    int32_t D = size<3>(tensor_d_);
    int32_t H = size<2>(tensor_d_);
    int32_t W = size<1>(tensor_d_);
    int32_t C = size<0>(tensor_d_);
    int32_t K = size<4>(tensor_b_);
    int32_t T = size<3>(tensor_b_);
    int32_t R = size<2>(tensor_b_);
    int32_t S = size<1>(tensor_b_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t n = 0; n < N; ++n) {
      for (int32_t d = 0; d < D; ++d) {
        for (int32_t h = 0; h < H; ++h) {
          for (int32_t w = 0; w < W; ++w) {
            for (int32_t c = 0; c < C; ++c) {
              auto accumulator = ElementAcc(0);
              for (int32_t k = 0; k < K; ++k) {
                for (int32_t t = 0; t < T; ++t) {
                  for (int32_t r = 0; r < R; ++r) {
                    for (int32_t s = 0; s < S; ++s) {
                      int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
                      int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
                      int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_);

                      if (q % cute::get<0>(tstride_) == 0) {
                        q /= cute::get<0>(tstride_);
                      } else {
                        continue;
                      }

                      if (p % cute::get<1>(tstride_) == 0) {
                        p /= cute::get<1>(tstride_);
                      } else {
                        continue;
                      }

                      if (z % cute::get<2>(tstride_) == 0) {
                        z /= cute::get<2>(tstride_);
                      } else {
                        continue;
                      }

                      if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k)) {
                        accumulator += ElementAcc(tensor_a_(k, q, p, z, n) * tensor_b_(c, s, r, t, k));
                      }
                    }
                  }
                }
              }
              ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
                ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
              ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
                ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
              ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
                scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n));
              if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
                output += bias_converter(epi_fusion_params_.tensor_bias[c]);
              }
              output = epi_activation(output);
              tensor_d_(c, w, h, d, n) = output_converter(output);
            }
          }
        }
      }
    }
  }

  // Specialization for 1D wgrad kernel
  void wgrad_reference(cute::Int<1> spatial_dims) {
    int32_t N = size<2>(tensor_a_);
    int32_t Q = size<1>(tensor_a_);
    int32_t K = size<0>(tensor_a_);
    int32_t S = size<1>(tensor_d_);
    int32_t C = size<0>(tensor_d_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(2)
#endif
    for (int32_t k = 0; k < K; ++k) {
      ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
        epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
      ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
        epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
      for (int32_t s = 0; s < S; ++s) {
        for (int32_t c = 0; c < C; ++c) {
          auto accumulator = ElementAcc(0);
          for (int32_t n = 0; n < N; ++n) {
            for (int32_t q = 0; q < Q; ++q) {
              int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
              if (detail::is_activation_in_bounds(tensor_b_, n, w, c)) {
                accumulator += ElementAcc(tensor_b_(c, w, n) * tensor_a_(k, q, n));
              }
            }
          }
          ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
            scale_converter(beta) * residual_converter(tensor_c_(c, s, k));
          if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
            output += bias_converter(epi_fusion_params_.tensor_bias[k]);
          }
          output = epi_activation(output);
          tensor_d_(c, s, k) = output_converter(output);
        }
      }
    }
  }
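
  // Structural note (illustrative, not part of the original header): wgrad
  // computes the filter gradient, so the roles invert -- (k, s, c) index the
  // output tensor while the reduction runs over the batch and output
  // positions (n, q). Per-channel alpha/beta are still keyed by the output
  // channel k, which is why they are hoisted out of the filter loops above.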

  // Specialization for 2D wgrad kernel
  void wgrad_reference(cute::Int<2> spatial_dims) {
    int32_t N = size<3>(tensor_a_);
    int32_t P = size<2>(tensor_a_);
    int32_t Q = size<1>(tensor_a_);
    int32_t K = size<0>(tensor_a_);
    int32_t R = size<2>(tensor_d_);
    int32_t S = size<1>(tensor_d_);
    int32_t C = size<0>(tensor_d_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t k = 0; k < K; ++k) {
      ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
        epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
      ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
        epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
      for (int32_t r = 0; r < R; ++r) {
        for (int32_t s = 0; s < S; ++s) {
          for (int32_t c = 0; c < C; ++c) {
            auto accumulator = ElementAcc(0);
            for (int32_t n = 0; n < N; ++n) {
              for (int32_t p = 0; p < P; ++p) {
                for (int32_t q = 0; q < Q; ++q) {
                  int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
                  int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
                  if (detail::is_activation_in_bounds(tensor_b_, n, h, w, c)) {
                    accumulator += ElementAcc(tensor_b_(c, w, h, n) * tensor_a_(k, q, p, n));
                  }
                }
              }
            }
            ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
              scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k));
            if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
              output += bias_converter(epi_fusion_params_.tensor_bias[k]);
            }
            output = epi_activation(output);
            tensor_d_(c, s, r, k) = output_converter(output);
          }
        }
      }
    }
  }

  // Specialization for 3D wgrad kernel
  void wgrad_reference(cute::Int<3> spatial_dims) {
    int32_t N = size<4>(tensor_a_);
    int32_t Z = size<3>(tensor_a_);
    int32_t P = size<2>(tensor_a_);
    int32_t Q = size<1>(tensor_a_);
    int32_t K = size<0>(tensor_a_);
    int32_t T = size<3>(tensor_d_);
    int32_t R = size<2>(tensor_d_);
    int32_t S = size<1>(tensor_d_);
    int32_t C = size<0>(tensor_d_);

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t k = 0; k < K; ++k) {
      ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
        epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
      ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
        epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
      for (int32_t t = 0; t < T; ++t) {
        for (int32_t r = 0; r < R; ++r) {
          for (int32_t s = 0; s < S; ++s) {
            for (int32_t c = 0; c < C; ++c) {
              auto accumulator = ElementAcc(0);
              for (int32_t n = 0; n < N; ++n) {
                for (int32_t z = 0; z < Z; ++z) {
                  for (int32_t p = 0; p < P; ++p) {
                    for (int32_t q = 0; q < Q; ++q) {
                      int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
                      int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
                      int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
                      if (detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c)) {
                        accumulator += ElementAcc(tensor_b_(c, w, h, d, n) * tensor_a_(k, q, p, z, n));
                      }
                    }
                  }
                }
              }
              ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) +
                scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k));
              if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
                output += bias_converter(epi_fusion_params_.tensor_bias[k]);
              }
              output = epi_activation(output);
              tensor_d_(c, s, r, t, k) = output_converter(output);
            }
          }
        }
      }
    }
  }
};
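
// Hypothetical end-to-end sketch (not part of the original header; unverified
// and illustrative only): builds tiny 1D fprop operands with cute and runs the
// reference. compute_reference() dispatches at compile time via if constexpr.
inline void example_run_1d_fprop_reference() {
  constexpr int C = 2, W = 8, N = 1, K = 3, S = 3, Q = 8;  // pad=1, stride=1, dilation=1
  static float A[C * W * N]  = {};  // activation (c, w, n)
  static float B[C * S * K]  = {};  // filter     (c, s, k)
  static float Cs[K * Q * N] = {};  // residual   (k, q, n)
  static float D[K * Q * N]  = {};  // output     (k, q, n)
  auto tA = cute::make_tensor(&A[0],  cute::make_shape(C, W, N));
  auto tB = cute::make_tensor(&B[0],  cute::make_shape(C, S, K));
  auto tC = cute::make_tensor(&Cs[0], cute::make_shape(K, Q, N));
  auto tD = cute::make_tensor(&D[0],  cute::make_shape(K, Q, N));
  auto padding  = cute::make_tuple(1);
  auto tstride  = cute::make_tuple(1);
  auto dilation = cute::make_tuple(1);
  using Params = ConvEpilogueFusionParams<
      float, float, float, float, float,
      decltype(tA), decltype(tA), decltype(tA)>;  // alpha/beta/bias tensors left unbound
  Params params;  // default alpha=1, beta=0; null tensor_alpha/beta/bias select the scalars
  ConvReferenceImpl<cutlass::conv::Operator::kFprop, 1,
                    decltype(tA), decltype(tB), decltype(tC), decltype(tD),
                    decltype(padding), decltype(tstride), decltype(dilation), Params>
      ref(tA, tB, tC, tD, padding, tstride, dilation, params);
  ref.compute_reference();
}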

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::reference::host

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -236,6 +236,7 @@ void gett_mainloop(
        acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]);
      }
    }

  }
}

@ -39,6 +39,7 @@
#include <cstdlib>
#include <cmath>
#include <random>
#include <stdexcept>

// Cutlass includes
#include "cutlass/cutlass.h"
@ -196,7 +197,7 @@ struct RandomGaussianFunc {
    // Sample from the Gaussian distribution for a nonzero element
    if (bernoulli_result) {
      if (int_scale >= 0) {
        rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
        rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale);
        result = static_cast<Element>(rnd);
      }
      else {
@ -567,7 +568,7 @@ struct RandomUniformFunc {
    // testing
    Element result;
    if (int_scale >= 0) {
      rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
      rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale);
      result = static_cast<Element>(Real(rnd));
    }
    else {
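
// Illustrative note (not part of the commit): routing through int64_t
// truncates toward zero, while std::llround rounds to nearest, halving the
// worst-case quantization error. With int_scale = 3, for example:
//   0.999 * 8 = 7.992  ->  int64_t(7.992) = 7  ->  7 / 8 = 0.875
//   0.999 * 8 = 7.992  ->  llround(7.992) = 8  ->  8 / 8 = 1.0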

@ -1381,9 +1382,13 @@ struct RandomSparseMetaFunc {
    std::srand((unsigned)seed);
    if (MetaSizeInBits_ == 2) {
      range = 6;
    } else if (MetaSizeInBits_ == 4) {
    }
    else if (MetaSizeInBits_ == 4) {
      range = 2;
    }
    else {
      throw std::invalid_argument("Invalid MetaSizeInBits");
    }
  }

  /// Compute random value and update RNG state
@ -61,7 +61,7 @@ ComputeType TensorTransformReduce(
  TransformOp transform
) {

  for (int64_t idx = 0; idx < view.size(); ++idx) {
  for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) {
    typename Layout::TensorCoord coord;
    cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view.extent());

@ -94,7 +94,7 @@ ComputeType TensorTransformReduce(
    throw std::runtime_error("Tensor extents must match.");
  }

  for (int64_t idx = 0; idx < view_A.size(); ++idx) {
  for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) {

    typename Layout::TensorCoord coord;
    cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view_A.extent());
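
// Illustrative note (not part of the commit): view.size() returns an unsigned
// type, so the added int64_t cast keeps the loop comparison signed and avoids
// -Wsign-compare warnings (or an unintended conversion of idx to unsigned).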