CUTLASS 2.0 (#62)
CUTLASS 2.0 Substantially refactored for - Better performance, particularly for native Turing Tensor Cores - Robust and durable templates spanning the design space - Encapsulated functionality embodying modern C++11 programming techniques - Optimized containers and data types for efficient, generic, portable device code Updates to: - Quick start guide - Documentation - Utilities - CUTLASS Profiler Native Turing Tensor Cores - Efficient GEMM kernels targeting Turing Tensor Cores - Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands Coverage of existing CUTLASS functionality: - GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs - Volta Tensor Cores through native mma.sync and through WMMA API - Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions - Batched GEMM operations - Complex-valued GEMMs Note: this commit and all that follow require a host compiler supporting C++11 or greater.
This commit is contained in:
@ -24,225 +24,193 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*
|
||||
This example demonstrates how to use the TileIterator in CUTLASS to load data from addressable
|
||||
memory, and store it back into addressable memory.
|
||||
This example demonstrates how to use the PredicatedTileIterator in CUTLASS to load data from
|
||||
addressable memory, and then store it back into addressable memory.
|
||||
|
||||
TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data from
|
||||
and to addressable memory. The TileIterator accepts a TileTraits type, which defines the shape of a
|
||||
tile and the distribution of accesses by individual entities, either threads or others.
|
||||
TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to
|
||||
and from addressable memory. The PredicatedTileIterator accepts a ThreadMap type, which defines
|
||||
the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined
|
||||
thread mappings to be specified.
|
||||
|
||||
In this example, a LoadTileIterator is used to load elements from a tile in global memory, stored in
|
||||
column-major layout, into a fragment, and a corresponding StoreTileIterator is used to store the
|
||||
elements back into global memory (in the same column-major layout).
|
||||
|
||||
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
|
||||
In this example, a PredicatedTileIterator is used to load elements from a tile in global memory,
|
||||
stored in column-major layout, into a fragment and then back into global memory in the same
|
||||
layout.
|
||||
|
||||
This example uses CUTLASS utilities to ease the matrix operations.
|
||||
|
||||
*/
|
||||
|
||||
// Standard Library includes
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
// CUTLASS includes
|
||||
#include "cutlass/tile_iterator.h"
|
||||
#include "cutlass/tile_traits_standard.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
|
||||
//
|
||||
// CUTLASS utility includes
|
||||
// CUTLASS utility includes
|
||||
//
|
||||
|
||||
// Defines operator<<() to write TensorView objects to std::ostream
|
||||
#include "tools/util/tensor_view_io.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
// Defines cutlass::HostMatrix<>
|
||||
#include "tools/util/host_matrix.h"
|
||||
// Defines cutlass::HostTensor<>
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
// Defines cutlass::reference::device::TensorInitialize()
|
||||
#include "tools/util/reference/device/tensor_elementwise.h"
|
||||
// Defines cutlass::reference::host::TensorFill() and
|
||||
// cutlass::reference::host::TensorFillBlockSequential()
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
|
||||
// Defines cutlass::reference::host::TensorEquals()
|
||||
#include "tools/util/reference/host/tensor_elementwise.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// This function defines load and store tile iterators to load and store an M-by-K tile, in
|
||||
// column-major layout, from and back into global memory.
|
||||
//
|
||||
#pragma warning( disable : 4503)
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Traits>
|
||||
__global__ void cutlass_tile_iterator_load_store_global(
|
||||
float const *input,
|
||||
float *output,
|
||||
int M,
|
||||
int K) {
|
||||
/// Define PredicatedTileIterators to load and store an M-by-K tile, in column major layout.
|
||||
|
||||
// Define a tile load iterator
|
||||
typedef cutlass::TileLoadIterator<
|
||||
Traits, // the Traits type, defines shape/distribution of accesses
|
||||
float, // elements are of type float
|
||||
cutlass::IteratorAdvance::kH, // post-increment accesses advance in strided (as opposed to
|
||||
// contiguous) dimension
|
||||
cutlass::MemorySpace::kGlobal // iterator loads from global memory
|
||||
> TileLoadIterator;
|
||||
template <typename Iterator>
|
||||
__global__ void copy(
|
||||
typename Iterator::Params dst_params,
|
||||
typename Iterator::Element *dst_pointer,
|
||||
typename Iterator::Params src_params,
|
||||
typename Iterator::Element *src_pointer,
|
||||
cutlass::Coord<2> extent) {
|
||||
|
||||
// Defines a tile store iterator
|
||||
typedef cutlass::TileStoreIterator<
|
||||
Traits, // the Traits type, defines shape/distribution of accesses
|
||||
float, // elements are of type float
|
||||
cutlass::IteratorAdvance::kH, // post-increment accesses advance in strided (as opposed to
|
||||
// contiguous) dimension
|
||||
cutlass::MemorySpace::kGlobal // iterator stores into global memory
|
||||
> TileStoreIterator;
|
||||
|
||||
// Defines a predicate vector for managing statically sized vector of boolean predicates
|
||||
typedef typename TileLoadIterator::PredicateVector PredicateVector;
|
||||
Iterator dst_iterator(dst_params, dst_pointer, extent, threadIdx.x);
|
||||
Iterator src_iterator(src_params, src_pointer, extent, threadIdx.x);
|
||||
|
||||
// The parameters specified to the iterators. These include the pointer to the source of
|
||||
// addressable memory, and the strides and increments for each of the tile's dimensions
|
||||
typename TileLoadIterator::Params load_params;
|
||||
typename TileStoreIterator::Params store_params;
|
||||
// PredicatedTileIterator uses PitchLinear layout and therefore takes in a PitchLinearShape.
|
||||
// The contiguous dimension can be accessed via Iterator::Shape::kContiguous and the strided
|
||||
// dimension can be accessed via Iterator::Shape::kStrided
|
||||
int iterations = (extent[1] + Iterator::Shape::kStrided - 1) / Iterator::Shape::kStrided;
|
||||
|
||||
// Initializing the parameters for both of the iterators. The TileLoadIterator accesses the
|
||||
// input matrix and TileStoreIterator accesses the output matrix. The strides are set
|
||||
// identically since the data is being stored in the same way as it is loaded (column-major
|
||||
// mapping).
|
||||
load_params.initialize(input, M*K, M, 1);
|
||||
store_params.initialize(output, M*K, M, 1);
|
||||
|
||||
// Constructing the tile load and store iterators, and the predicates vector
|
||||
TileLoadIterator load_iterator(load_params);
|
||||
TileStoreIterator store_iterator(store_params);
|
||||
PredicateVector predicates;
|
||||
typename Iterator::Fragment fragment;
|
||||
|
||||
// Initializing the predicates with bounds set to <1, K, M>. This protects out-of-bounds loads.
|
||||
load_iterator.initialize_predicates(predicates.begin(), cutlass::make_Coord(1, K, M));
|
||||
for(int i = 0; i < fragment.size(); ++i) {
|
||||
fragment[i] = 0;
|
||||
}
|
||||
|
||||
// The fragment in which the elements are loaded into and stored from.
|
||||
typename TileLoadIterator::Fragment fragment;
|
||||
src_iterator.load(fragment);
|
||||
dst_iterator.store(fragment);
|
||||
|
||||
// Loading a tile into a fragment and advancing to the next tile's position
|
||||
load_iterator.load_post_increment(fragment, predicates.begin());
|
||||
// Storing a tile from fragment and advancing to the next tile's position
|
||||
store_iterator.store_post_increment(fragment);
|
||||
|
||||
++src_iterator;
|
||||
++dst_iterator;
|
||||
|
||||
for(; iterations > 1; --iterations) {
|
||||
|
||||
src_iterator.load(fragment);
|
||||
dst_iterator.store(fragment);
|
||||
|
||||
++src_iterator;
|
||||
++dst_iterator;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Launches cutlass_tile_iterator_load_store_global kernel
|
||||
cudaError_t test_cutlass_tile_iterator() {
|
||||
cudaError_t result = cudaSuccess;
|
||||
// Initializes the source tile with sequentially increasing values and performs the copy into
|
||||
// the destination tile using two PredicatedTileIterators, one to load the data from addressable
|
||||
// memory into a fragment (register-backed array of elements owned by each thread) and another to
|
||||
// store the data from the fragment back into the addressable memory of the destination tile.
|
||||
|
||||
// Creating an M-by-K (128-by-8) tile for this example.
|
||||
static int const M = 128;
|
||||
static int const K = 8;
|
||||
// The kernel is launched with 128 threads per thread block.
|
||||
static int const kThreadsPerThreadBlock = 128;
|
||||
// Define the tile type
|
||||
typedef cutlass::Shape<1, 8, 128> Tile;
|
||||
cudaError_t TestTileIterator(int M, int K) {
|
||||
|
||||
// CUTLASS provides a standard TileTraits type, which chooses the 'best' shape to enable warp
|
||||
// raking along the contiguous dimension if possible.
|
||||
typedef cutlass::TileTraitsStandard<Tile, kThreadsPerThreadBlock> Traits;
|
||||
// For this example, we chose a <64, 4> tile shape. The PredicatedTileIterator expects
|
||||
// PitchLinearShape and PitchLinear layout.
|
||||
using Shape = cutlass::layout::PitchLinearShape<64, 4>;
|
||||
using Layout = cutlass::layout::PitchLinear;
|
||||
using Element = int;
|
||||
int const kThreads = 32;
|
||||
|
||||
// M-by-K input matrix of float
|
||||
cutlass::HostMatrix<float> input(cutlass::MatrixCoord(M, K));
|
||||
// ThreadMaps define how threads are mapped to a given tile. The PitchLinearStripminedThreadMap
|
||||
// stripmines a pitch-linear tile among a given number of threads, first along the contiguous
|
||||
// dimension then along the strided dimension.
|
||||
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<Shape, kThreads>;
|
||||
|
||||
// M-by-K output matrix of float
|
||||
cutlass::HostMatrix<float> output(cutlass::MatrixCoord(M, K));
|
||||
// Define the PredicatedTileIterator, using TileShape, Element, Layout, and ThreadMap types
|
||||
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
Shape, Element, Layout, 1, ThreadMap>;
|
||||
|
||||
//
|
||||
// Initialize input matrix with linear combination.
|
||||
//
|
||||
|
||||
cutlass::Distribution dist;
|
||||
cutlass::Coord<2> copy_extent = cutlass::make_Coord(M, K);
|
||||
cutlass::Coord<2> alloc_extent = cutlass::make_Coord(M, K);
|
||||
|
||||
// Linear distribution in column-major format.
|
||||
dist.set_linear(1, 1, M);
|
||||
// Allocate source and destination tensors
|
||||
cutlass::HostTensor<Element, Layout> src_tensor(alloc_extent);
|
||||
cutlass::HostTensor<Element, Layout> dst_tensor(alloc_extent);
|
||||
|
||||
// Arbitrary RNG seed value. Hard-coded for deterministic results.
|
||||
int seed = 2080;
|
||||
Element oob_value = Element(-1);
|
||||
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
input.device_view(), // concept: TensorView
|
||||
seed,
|
||||
dist);
|
||||
// Initialize destination tensor with all -1s
|
||||
cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value);
|
||||
// Initialize source tensor with sequentially increasing values
|
||||
cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity());
|
||||
|
||||
// Initialize output matrix to all zeroes.
|
||||
output.fill(0);
|
||||
dst_tensor.sync_device();
|
||||
src_tensor.sync_device();
|
||||
|
||||
// Launch kernel to load and store tiles from/to global memory.
|
||||
cutlass_tile_iterator_load_store_global<Traits><<<
|
||||
dim3(1, 1, 1),
|
||||
dim3(kThreadsPerThreadBlock, 1)
|
||||
>>>(input.device_data(), output.device_data(), M, K);
|
||||
typename Iterator::Params dst_params(dst_tensor.layout());
|
||||
typename Iterator::Params src_params(src_tensor.layout());
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
dim3 block(kThreads, 1);
|
||||
dim3 grid(1, 1);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return result;
|
||||
}
|
||||
// Launch copy kernel to perform the copy
|
||||
copy<Iterator><<< grid, block >>>(
|
||||
dst_params,
|
||||
dst_tensor.device_data(),
|
||||
src_params,
|
||||
src_tensor.device_data(),
|
||||
copy_extent
|
||||
);
|
||||
|
||||
// Copy results to host
|
||||
output.sync_host();
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if(result != cudaSuccess) {
|
||||
std::cerr << "Error - kernel failed." << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Verify results
|
||||
for(int i = 0; i < M; ++i) {
|
||||
for(int j = 0; j < K; ++j) {
|
||||
if(output.at(cutlass::make_Coord(i, j)) != float(M*j+i+1)){
|
||||
std::cout << "FAILED: (" << i << ", " << j
|
||||
<< ") -- expected: " << (M*j+i+1)
|
||||
<< ", actual: " << output.at(cutlass::make_Coord(i, j))
|
||||
<< std::endl;
|
||||
result = cudaErrorUnknown;
|
||||
break;
|
||||
dst_tensor.sync_host();
|
||||
|
||||
// Verify results
|
||||
for(int s = 0; s < alloc_extent[1]; ++s) {
|
||||
for(int c = 0; c < alloc_extent[0]; ++c) {
|
||||
|
||||
Element expected = Element(0);
|
||||
|
||||
if(c < copy_extent[0] && s < copy_extent[1]) {
|
||||
expected = src_tensor.at({c, s});
|
||||
}
|
||||
else {
|
||||
expected = oob_value;
|
||||
}
|
||||
|
||||
Element got = dst_tensor.at({c, s});
|
||||
bool equal = (expected == got);
|
||||
|
||||
if(!equal) {
|
||||
std::cerr << "Error - source tile differs from destination tile." << std::endl;
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Entry point to tile_iterator example.
|
||||
//
|
||||
// usage:
|
||||
//
|
||||
// 04_tile_iterator
|
||||
//
|
||||
int main(int argc, const char *arg[]) {
|
||||
|
||||
// Properties of CUDA device
|
||||
cudaDeviceProp device_properties;
|
||||
|
||||
// Assume the device id is 0.
|
||||
int device_id = 0;
|
||||
|
||||
cudaError_t result = cudaGetDeviceProperties(&device_properties, device_id);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to get device properties: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
cudaError_t result = TestTileIterator(57, 35);
|
||||
|
||||
if(result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Run the CUTLASS tile iterator test.
|
||||
//
|
||||
|
||||
result = test_cutlass_tile_iterator();
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
// Exit.
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
// Exit
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
Reference in New Issue
Block a user