CUTLASS 2.0 (#62)

CUTLASS 2.0

Substantially refactored for

- Better performance, particularly for native Turing Tensor Cores
- Robust and durable templates spanning the design space
- Encapsulated functionality embodying modern C++11 programming techniques
- Optimized containers and data types for efficient, generic, portable device code

Updates to:
- Quick start guide
- Documentation
- Utilities
- CUTLASS Profiler

Native Turing Tensor Cores
- Efficient GEMM kernels targeting Turing Tensor Cores
- Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands

Coverage of existing CUTLASS functionality:
- GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs
- Volta Tensor Cores through native mma.sync and through WMMA API
- Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
- Batched GEMM operations
- Complex-valued GEMMs

Note: this commit and all that follow require a host compiler supporting C++11 or greater.
This commit is contained in:
Andrew Kerr
2019-11-19 16:55:34 -08:00
committed by GitHub
parent b5cab177a9
commit fb335f6a5f
5434 changed files with 599799 additions and 250176 deletions

View File

@ -40,6 +40,11 @@
Aside from defining and launching the SGEMM kernel, this example does not use any other components
or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
prevalent in the CUTLASS unit tests.
This example has delibrately been kept similar to the basic_gemm example from cutass-1.3 to
highlight the minimum amount of differences needed to transition to cutlass-2.0.
Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu
*/
// Standard Library includes
@ -47,17 +52,15 @@
#include <sstream>
#include <vector>
// Helper methods to check for errors
#include "helper.h"
//
// CUTLASS includes needed for single-precision GEMM kernel
//
// Defines cutlass::gemm::Gemm, the generic Gemm computation template class.
#include "cutlass/gemm/gemm.h"
// Defines cutlass::gemm::SgemmTraits, the structural components for single-precision GEMM
#include "cutlass/gemm/sgemm_traits.h"
#pragma warning( disable : 4503)
// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class.
#include "cutlass/gemm/device/gemm.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
//
@ -81,63 +84,58 @@ cudaError_t CutlassSgemmNN(
int ldc) {
// Define type definition for single-precision CUTLASS GEMM with column-major
// input matrices and 128x128x8 threadblock tile size.
//
// Note, GemmTraits<> is a generic template defined for various general matrix product
// computations within CUTLASS. It is intended to be maximally flexible, and consequently
// it contains numerous template arguments.
// input matrices and 128x128x8 threadblock tile size (chosen by default).
//
// To keep the interface manageable, several helpers are defined for plausible compositions
// including the following example for single-precision GEMM. Typical values are used as
// default template arguments. See `cutlass/gemm/gemm_traits.h` for more details.
// default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details.
//
typedef cutlass::gemm::SgemmTraits<
cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
cutlass::MatrixLayout::kColumnMajor, // layout of B matrix
cutlass::Shape<8, 128, 128> // threadblock tile size
>
GemmTraits;
// To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h`
// Define a CUTLASS GEMM type from a GemmTraits<> instantiation.
typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
using ColumnMajor = cutlass::layout::ColumnMajor;
// Construct and initialize CUTLASS GEMM parameters object.
using CutlassGemm = cutlass::gemm::device::Gemm<float, // Data-type of A matrix
ColumnMajor, // Layout of A matrix
float, // Data-type of B matrix
ColumnMajor, // Layout of B matrix
float, // Data-type of C matrix
ColumnMajor>; // Layout of C matrix
// Define a CUTLASS GEMM type
CutlassGemm gemm_operator;
// Construct the CUTLASS GEMM arguments object.
//
// One of CUTLASS's design patterns is to define parameters objects that are constructible
// One of CUTLASS's design patterns is to define gemm argument objects that are constructible
// in host code and passed to kernels by value. These may include pointers, strides, scalars,
// and other arguments needed by Gemm and its components.
//
// The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
// arguments to kernels and (2.) minimized initialization overhead on kernel entry.
//
typename Gemm::Params params;
CutlassGemm::Arguments args({M , N, K}, // Gemm Problem dimensions
{A, lda}, // Tensor-ref for source matrix A
{B, ldb}, // Tensor-ref for source matrix B
{C, ldc}, // Tensor-ref for source matrix C
{C, ldc}, // Tensor-ref for destination matrix D (may be different memory than source C matrix)
{alpha, beta}); // Scalars used in the Epilogue
int result = params.initialize(
M, // GEMM M dimension
N, // GEMM N dimension
K, // GEMM K dimension
alpha, // scalar alpha
A, // matrix A operand
lda,
B, // matrix B operand
ldb,
beta, // scalar beta
C, // source matrix C
ldc,
C, // destination matrix C (may be different memory than source C matrix)
ldc
);
//
// Launch the CUTLASS GEMM kernel.
//
cutlass::Status status = gemm_operator(args);
if (result) {
std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
return cudaErrorInvalidValue;
//
// Return a cudaError_t if the CUTLASS GEMM operator returned an error code.
//
if (status != cutlass::Status::kSuccess) {
return cudaErrorUnknown;
}
// Launch the CUTLASS GEMM kernel.
Gemm::launch(params);
// Return any errors associated with the launch or cudaSuccess if no error.
return cudaGetLastError();
// Return success, if no errors were encountered.
return cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////////////////////////