CUTLASS 2.0 (#62)
CUTLASS 2.0 Substantially refactored for - Better performance, particularly for native Turing Tensor Cores - Robust and durable templates spanning the design space - Encapsulated functionality embodying modern C++11 programming techniques - Optimized containers and data types for efficient, generic, portable device code Updates to: - Quick start guide - Documentation - Utilities - CUTLASS Profiler Native Turing Tensor Cores - Efficient GEMM kernels targeting Turing Tensor Cores - Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands Coverage of existing CUTLASS functionality: - GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs - Volta Tensor Cores through native mma.sync and through WMMA API - Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions - Batched GEMM operations - Complex-valued GEMMs Note: this commit and all that follow require a host compiler supporting C++11 or greater.
This commit is contained in:
@ -40,6 +40,11 @@
|
||||
Aside from defining and launching the SGEMM kernel, this example does not use any other components
|
||||
or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
|
||||
prevalent in the CUTLASS unit tests.
|
||||
|
||||
This example has delibrately been kept similar to the basic_gemm example from cutass-1.3 to
|
||||
highlight the minimum amount of differences needed to transition to cutlass-2.0.
|
||||
|
||||
Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu
|
||||
*/
|
||||
|
||||
// Standard Library includes
|
||||
@ -47,17 +52,15 @@
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
// Helper methods to check for errors
|
||||
#include "helper.h"
|
||||
|
||||
//
|
||||
// CUTLASS includes needed for single-precision GEMM kernel
|
||||
//
|
||||
|
||||
// Defines cutlass::gemm::Gemm, the generic Gemm computation template class.
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
// Defines cutlass::gemm::SgemmTraits, the structural components for single-precision GEMM
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
|
||||
#pragma warning( disable : 4503)
|
||||
// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class.
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
@ -81,63 +84,58 @@ cudaError_t CutlassSgemmNN(
|
||||
int ldc) {
|
||||
|
||||
// Define type definition for single-precision CUTLASS GEMM with column-major
|
||||
// input matrices and 128x128x8 threadblock tile size.
|
||||
//
|
||||
// Note, GemmTraits<> is a generic template defined for various general matrix product
|
||||
// computations within CUTLASS. It is intended to be maximally flexible, and consequently
|
||||
// it contains numerous template arguments.
|
||||
// input matrices and 128x128x8 threadblock tile size (chosen by default).
|
||||
//
|
||||
// To keep the interface manageable, several helpers are defined for plausible compositions
|
||||
// including the following example for single-precision GEMM. Typical values are used as
|
||||
// default template arguments. See `cutlass/gemm/gemm_traits.h` for more details.
|
||||
// default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details.
|
||||
//
|
||||
typedef cutlass::gemm::SgemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
|
||||
cutlass::MatrixLayout::kColumnMajor, // layout of B matrix
|
||||
cutlass::Shape<8, 128, 128> // threadblock tile size
|
||||
>
|
||||
GemmTraits;
|
||||
// To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h`
|
||||
|
||||
// Define a CUTLASS GEMM type from a GemmTraits<> instantiation.
|
||||
typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
|
||||
using ColumnMajor = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Construct and initialize CUTLASS GEMM parameters object.
|
||||
using CutlassGemm = cutlass::gemm::device::Gemm<float, // Data-type of A matrix
|
||||
ColumnMajor, // Layout of A matrix
|
||||
float, // Data-type of B matrix
|
||||
ColumnMajor, // Layout of B matrix
|
||||
float, // Data-type of C matrix
|
||||
ColumnMajor>; // Layout of C matrix
|
||||
|
||||
// Define a CUTLASS GEMM type
|
||||
CutlassGemm gemm_operator;
|
||||
|
||||
// Construct the CUTLASS GEMM arguments object.
|
||||
//
|
||||
// One of CUTLASS's design patterns is to define parameters objects that are constructible
|
||||
// One of CUTLASS's design patterns is to define gemm argument objects that are constructible
|
||||
// in host code and passed to kernels by value. These may include pointers, strides, scalars,
|
||||
// and other arguments needed by Gemm and its components.
|
||||
//
|
||||
// The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
|
||||
// arguments to kernels and (2.) minimized initialization overhead on kernel entry.
|
||||
//
|
||||
typename Gemm::Params params;
|
||||
CutlassGemm::Arguments args({M , N, K}, // Gemm Problem dimensions
|
||||
{A, lda}, // Tensor-ref for source matrix A
|
||||
{B, ldb}, // Tensor-ref for source matrix B
|
||||
{C, ldc}, // Tensor-ref for source matrix C
|
||||
{C, ldc}, // Tensor-ref for destination matrix D (may be different memory than source C matrix)
|
||||
{alpha, beta}); // Scalars used in the Epilogue
|
||||
|
||||
int result = params.initialize(
|
||||
M, // GEMM M dimension
|
||||
N, // GEMM N dimension
|
||||
K, // GEMM K dimension
|
||||
alpha, // scalar alpha
|
||||
A, // matrix A operand
|
||||
lda,
|
||||
B, // matrix B operand
|
||||
ldb,
|
||||
beta, // scalar beta
|
||||
C, // source matrix C
|
||||
ldc,
|
||||
C, // destination matrix C (may be different memory than source C matrix)
|
||||
ldc
|
||||
);
|
||||
//
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
//
|
||||
|
||||
cutlass::Status status = gemm_operator(args);
|
||||
|
||||
if (result) {
|
||||
std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
//
|
||||
// Return a cudaError_t if the CUTLASS GEMM operator returned an error code.
|
||||
//
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
Gemm::launch(params);
|
||||
|
||||
// Return any errors associated with the launch or cudaSuccess if no error.
|
||||
return cudaGetLastError();
|
||||
// Return success, if no errors were encountered.
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user