CUTLASS 3.5.1 (#1623)

* CUTLASS 3.5.1

* updates, optimizations, fixes
Vijay Thakkar authored 2024-07-29 08:46:24 -04:00, committed by GitHub
commit be60a0b272 (parent 56b46e2d13)
312 changed files with 19793 additions and 6775 deletions


@ -232,7 +232,7 @@ int main() {
## Launching a GEMM kernel in CUDA
**Example:** launch a mixed-precision GEMM targeting Turing Tensor Cores.
_Note, this example uses CUTLASS Utilities. Be sure `tools/util/include` is listed as an include path._
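Among those utilities, `cutlass::HostTensor<>` manages paired host- and device-side allocations and provides the `device_ref()` accessor used further below. A minimal hedged sketch of its use (the matrix shape and value here are illustrative only):

```c++
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"

int main() {
  // HostTensor allocates matching host and device buffers for a
  // 128x64 column-major half-precision matrix.
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> A({128, 64});

  // Write through the host view, then mirror the contents to the device.
  A.host_view().at({0, 0}) = cutlass::half_t(1.0f);
  A.sync_device();

  // device_ref() yields the TensorRef consumed by GEMM argument structures.
  auto ref_A = A.device_ref();
  (void)ref_A;

  return 0;
}
```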
```c++
@ -289,7 +289,7 @@ int main() {
//
// Launch GEMM on the device
//
status = gemm_op({
{M, N, K},
{ptrA, lda}, // TensorRef to A device tensor
@ -315,7 +315,7 @@ Note, the above could be simplified as follows using helper methods defined in `
//
// Use the TensorRef returned by HostTensor::device_ref().
//
status = gemm_op({
{M, N, K},
@ -329,7 +329,7 @@ Note, the above could be simplified as follows using helper methods defined in `
## Launching a GEMM kernel using CUTLASS 3.0 or newer
**Example:** launch a mixed-precision GEMM targeting Hopper Tensor Cores.
```c++
#include "cutlass/cutlass.h"
@ -367,7 +367,7 @@ int main(int argc, char const **args) {
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
@ -425,10 +425,10 @@ int main(int argc, char const **args) {
StrideC stride_C;
StrideD stride_D;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
block_A.reset(M * K);
block_B.reset(K * N);
@ -438,7 +438,7 @@ int main(int argc, char const **args) {
//
// Launch GEMM on the device
//
status = gemm_op({
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
@ -462,9 +462,9 @@ int main(int argc, char const **args) {
The [CUTLASS Library](/tools/library) defines an API for managing and executing collections of compiled
kernel instances and launching them from host code without template instantiations in client code.
The host-side launch API is designed to be analogous to BLAS implementations for convenience, though its
kernel selection procedure is intended only to be functionally sufficient. It may not launch the
kernel with the optimal tile size for a given problem. It chooses the first available kernel whose data types,
layouts, and alignment constraints satisfy the given problem. Kernel instances and a data structure
describing them are completely available to client applications which may choose to implement their
own selection logic.
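For instance, a client implementing its own selection logic might enumerate the compiled instances through the library's manifest. The following is a hedged sketch, assuming the `Singleton` and `Manifest` interfaces declared under `tools/library/include`; the filtering shown is illustrative, not the library's own selection procedure:

```c++
#include <iostream>

#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/singleton.h"

int main() {
  // The process-wide singleton's manifest owns every operation compiled
  // into the CUTLASS Library.
  cutlass::library::Manifest const &manifest =
      cutlass::library::Singleton::get().manifest;

  for (auto const &operation : manifest.operations()) {
    cutlass::library::OperationDescription const &desc = operation->description();

    // Custom selection could filter on desc.kind, tile sizes, or data types;
    // this sketch simply lists the available GEMM kernels by name.
    if (desc.kind == cutlass::library::OperationKind::kGemm) {
      std::cout << desc.name << "\n";
    }
  }

  return 0;
}
```

The manifest is populated when the application links against the CUTLASS Library, built as described below.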
@ -479,12 +479,12 @@ by several SDK examples.
* [11_planar_complex_array](/examples/11_planar_complex_array/planar_complex_array.cu)
The CUTLASS Library defines enumerated types describing numeric data types, matrix and tensor
layouts, math operation classes, complex transformations, and more.
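As a small hedged illustration, a few of these runtime descriptors (declared in `cutlass/library/library.h`) stand in for the template parameters a client would otherwise instantiate:

```c++
#include "cutlass/library/library.h"

int main() {
  // Runtime descriptions of an FP16, column-major operand; the handle API
  // accepts these in place of compile-time template arguments.
  cutlass::library::NumericTypeID element = cutlass::library::NumericTypeID::kF16;
  cutlass::library::LayoutTypeID layout = cutlass::library::LayoutTypeID::kColumnMajor;

  (void)element;
  (void)layout;
  return 0;
}
```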
Client applications should specify [`tools/library/include`](/tools/library/include) in their
include paths and link against libcutlass_lib.so.
The CUTLASS SDK example [10_planar_complex](/examples/10_planar_complex/CMakeLists.txt) specifies
its dependency on the CUTLASS Library with the following CMake command.
```
target_link_libraries(
@ -534,7 +534,7 @@ int main() {
//
// CUTLASS Library call to execute device GEMM
//
cutlass::library::Handle handle;
//
@ -571,7 +571,7 @@ int main() {
ptrD, // pointer to D matrix in device memory
ldd // leading dimension of D matrix
);
if (status != cutlass::Status::kSuccess) {
return -1;
}
@ -580,27 +580,27 @@ int main() {
}
```
# Example CMake Commands
To instantiate all operations supporting all tile sizes, data types, and alignment constraints, specify
`-DCUTLASS_LIBRARY_KERNELS=all` when running `cmake`.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=all
```
The above command line generates about twenty thousand kernels targeting NVIDIA Ampere, Turing, and Volta architectures.
Compiling thousands of kernels for three different architectures is time-consuming. Additionally, this results
in a large binary size and, on some platforms, causes the linker to fail when building the library.
Enabling the "unity build" instantiates multiple kernel instances in each compilation unit, thereby reducing binary size
and avoiding linker limitations on some platforms.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=ON
```
It is advised to only compile CUTLASS kernels for NVIDIA architectures one plans on running. Furthermore, kernels
can be selectively included in the CUTLASS Library by specifying filter strings and wildcard characters when executing CMake.
Several examples are defined below for convenience. They may be combined as a comma-delimited list.
Compiling only the kernels desired reduces compilation time.
@ -646,7 +646,7 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS='50;60;61;70;75;80' -DCUTLASS_LIBRARY_KERNELS=sf
$ cmake .. -DCUTLASS_NVCC_ARCHS='80' -DCUTLASS_LIBRARY_KERNELS=s16816fprop_*_f16
```
**Example.** All backward weight gradient (wgrad) convolution kernels with FP32 accumulation, FP16 input, and optimized global memory iterator
targeting NVIDIA Ampere, Turing, and Volta Tensor Core operations
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=tensorop*s*wgrad_optimized_f16