# Efficient GEMM in CUDA

CUTLASS implements the hierarchically blocked structure described in
[CUTLASS: Fast Linear Algebra in CUDA C++](https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/)
and the [CUTLASS GTC2018 talk](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).

## Hierarchical Structure

The basic triple loop nest computing matrix multiply may be blocked and tiled to match
concurrency in hardware, memory locality, and parallel programming models. In CUTLASS,
GEMM is mapped to NVIDIA GPUs with the structure illustrated by the following loop nest.

```c++
for (int cta_n = 0; cta_n < GemmN; cta_n += CtaTileN) {                    // for each threadblock_y           } threadblock-level concurrency
  for (int cta_m = 0; cta_m < GemmM; cta_m += CtaTileM) {                  //    for each threadblock_x        }

    for (int cta_k = 0; cta_k < GemmK; cta_k += CtaTileK) {                //       "GEMM mainloop" - no unrolling
                                                                           //       - one iteration of this loop is one "stage"
      //
      for (int warp_n = 0; warp_n < CtaTileN; warp_n += WarpTileN) {       // for each warp_y                  } warp-level parallelism
        for (int warp_m = 0; warp_m < CtaTileM; warp_m += WarpTileM) {     //    for each warp_x               }
          //
          for (int warp_k = 0; warp_k < CtaTileK; warp_k += WarpTileK) {   //       fully unrolled across CtaTileK
                                                                           //       - one iteration of this loop is one "k Group"
            //
            for (int mma_k = 0; mma_k < WarpTileK; mma_k += MmaK) {        // for each mma instruction         } instruction-level parallelism
              for (int mma_n = 0; mma_n < WarpTileN; mma_n += MmaN) {      //    for each mma instruction      }
                for (int mma_m = 0; mma_m < WarpTileM; mma_m += MmaM) {    //    for each mma instruction      }
                  //
                  mma_instruction(d, a, b, c);                             //       TensorCore matrix computation

                }   // for mma_m
              }   // for mma_n
            }   // for mma_k
          }   // for warp_k
        }   // for warp_m
      }   // for warp_n
    }   // for cta_k
  }   // for cta_m
}   // for cta_n
```

This tiled loop nest targets concurrency among

- threadblocks,
- warps, and
- CUDA and Tensor Cores.

It takes advantage of memory locality within

- shared memory and
- registers.

The figure below illustrates the flow of data within this structure.
This is the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a
nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a
level within the memory hierarchy, becoming increasingly finer moving left to right.

![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Hierarchical GEMM in CUTLASS")

### Threadblock-level GEMM

Each threadblock computes its portion of the output GEMM by iteratively loading tiles of input
matrices and computing an accumulated matrix product. At the threadblock level, data are loaded from
global memory. The blocking strategy in general is key to achieving efficiency. However, the programmer
must balance multiple conflicting goals. A
larger threadblock means fewer fetches from global memory, thereby ensuring that DRAM bandwidth
does not become a bottleneck.

However, large threadblock tiles may not match the dimensions of the problem well. If either the
GEMM _M_ or _N_ dimension is small, some threads within the threadblock may not perform meaningful
work, as the threadblock may be partially outside the bounds of the problem. If both _M_ and _N_
are small while _K_ is large, this scheme may launch relatively few threadblocks and fail to
make full use of all multiprocessors within the GPU. Strategies to optimize performance for this case,
as described in the section [Parallelized Reductions](efficient_gemm.md#parallelized-reductions),
partition the GEMM K dimension across multiple threadblocks or multiple warps. These threadblocks
or warps compute matrix products in parallel; the products are then reduced to compute the result.

In CUTLASS, the dimensions of the threadblock tile are specified as `ThreadblockShape::{kM, kN, kK}`
and may be tuned to specialize the GEMM computation for the target processor and dimensions of
the GEMM problem.
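
For illustration, the following sketch shows how these tile dimensions might be supplied when
instantiating a device-level CUTLASS GEMM. The element types, layouts, target architecture, and
tile sizes here are assumptions chosen for a single-precision SIMT (CUDA core) GEMM, not required
values, and the exact template parameters can vary across CUTLASS versions.

```c++
#include "cutlass/gemm/device/gemm.h"

// A sketch: single-precision GEMM specialized with explicit tile shapes.
// GemmShape<kM, kN, kK> supplies the threadblock, warp, and instruction
// tiles, from coarsest to finest level of the hierarchy.
using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,    // element and layout of A
    float, cutlass::layout::ColumnMajor,    // element and layout of B
    float, cutlass::layout::ColumnMajor,    // element and layout of C
    float,                                  // accumulator element
    cutlass::arch::OpClassSimt,             // issue to CUDA cores
    cutlass::arch::Sm80,                    // target architecture (assumed)
    cutlass::gemm::GemmShape<128, 128, 8>,  // ThreadblockShape::{kM, kN, kK}
    cutlass::gemm::GemmShape<32, 64, 8>,    // warp-level tile
    cutlass::gemm::GemmShape<1, 1, 1>>;     // instruction-level tile (SIMT)
```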

### Warp-level GEMM

The warp-level GEMM maps to the warp-level parallelism within the CUDA execution model. Multiple
warps within a threadblock fetch data from shared memory into registers and perform computations.
Warp-level GEMMs may be implemented either by TensorCores issuing
[mma.sync](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma)
or [wmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-mma)
instructions, or by thread-level matrix computations issued to CUDA cores.
For maximum performance, access to shared memory should be bank conflict free. To maximize data
reuse within the warp, a large warp-level GEMM tile should be chosen.
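
As a standalone illustration of the wmma path (independent of CUTLASS's internal abstractions),
the kernel below computes a single 16x16x16 warp-level matrix product with the CUDA `wmma` API.
The shapes, layouts, and half-precision element types are assumptions chosen from the API's
supported configurations.

```c++
#include <mma.h>
using namespace nvcuda;

// One warp computes D = A * B + C for a single 16x16x16 tile held in fragments.
__global__ void warp_level_mma(half const *A, half const *B,
                               float const *C, float *D) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> frag_b;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> frag_c;

  wmma::load_matrix_sync(frag_a, A, 16);  // load operand tiles into registers
  wmma::load_matrix_sync(frag_b, B, 16);  // (leading dimension 16)
  wmma::load_matrix_sync(frag_c, C, 16, wmma::mem_row_major);

  wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);  // Tensor Core computation

  wmma::store_matrix_sync(D, frag_c, 16, wmma::mem_row_major);
}
```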

### Thread-level GEMM

At the lowest level of blocking, each thread is responsible for processing a certain number of
elements. Threads cannot access each other's registers, so we choose an organization that enables
reuse of values held in registers for multiple math instructions. This results in a 2D tiled
structure within a thread, in which each thread issues a sequence of independent math instructions
to the CUDA cores and computes an accumulated outer product.
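
To make the outer-product structure concrete, here is a minimal sketch of one K-step of the
per-thread computation; it is not CUTLASS source, and the fragment names and sizes are
hypothetical.

```c++
// Minimal sketch (not CUTLASS source): one K-step of the per-thread
// accumulated outer product over a ThreadM x ThreadN register tile.
template <int ThreadM, int ThreadN>
__device__ void thread_level_outer_product(
    float const (&frag_a)[ThreadM],      // column fragment of A, in registers
    float const (&frag_b)[ThreadN],      // row fragment of B, in registers
    float (&accum)[ThreadM][ThreadN]) {  // 2D tile of accumulators
  // Each register value is reused ThreadN (or ThreadM) times across
  // independent FMA instructions, amortizing loads over math.
  #pragma unroll
  for (int m = 0; m < ThreadM; ++m) {
    #pragma unroll
    for (int n = 0; n < ThreadN; ++n) {
      accum[m][n] += frag_a[m] * frag_b[n];
    }
  }
}
```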
## Optimizations

The hierarchical structure described above yields an efficient mapping to the CUDA execution model and
CUDA/TensorCores in NVIDIA GPUs. The following sections describe strategies for obtaining peak performance
for all corners of the design space, maximizing parallelism and exploiting data locality wherever possible.

### Pipelining

The blocked structure demands a large storage allocation within the registers of each CUDA thread. The
accumulator elements typically occupy at least half a thread's total register budget. Consequently,
occupancy -- the number of concurrent threads, warps, and threadblocks -- is relatively low compared
to other classes of GPU workloads. This limits the GPU's ability to hide memory latency and other stalls
by context switching to other concurrent threads within an SM.

To mitigate the effects of memory latency, CUTLASS uses *software pipelining* to overlap memory accesses
with other computation within a thread. CUTLASS accomplishes this by double buffering at the
following scopes:

- **Threadblock-scoped shared memory tiles:** two tiles are allocated in shared memory
  (see the sketch after this list). One is used to load data for the current matrix operation,
  while the other tile is used to buffer data loaded from global memory
  for the next mainloop iteration.

- **Warp-scoped matrix fragments:** two fragments are allocated within registers.
  One fragment is passed to CUDA and TensorCores during the current matrix computation,
  while the other is used to receive shared memory fetch returns
  for the next warp-level matrix operation.
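
The following is a simplified sketch of the threadblock-scoped double buffering described above.
It is not the actual CUTLASS mainloop: the tile sizes are template parameters, and
`load_global_to_smem()` and `warp_level_gemm()` are hypothetical helpers standing in for
CUTLASS's tile iterators and warp-level MMA.

```c++
// Hypothetical helpers standing in for CUTLASS's tile iterators and
// warp-level MMA; their definitions are omitted in this sketch.
__device__ void load_global_to_smem(float *smem_a, float *smem_b,
                                    float const *A, float const *B, int cta_k);
__device__ void warp_level_gemm(float const *smem_a, float const *smem_b);

// Sketch only: two shared memory stages per operand. While math consumes one
// stage, the next global memory fetch fills the other.
template <int CtaTileM, int CtaTileN, int CtaTileK>
__global__ void gemm_mainloop_sketch(float const *A, float const *B, int GemmK) {
  __shared__ float smem_a[2][CtaTileM * CtaTileK];  // stage 0 and stage 1 for A
  __shared__ float smem_b[2][CtaTileK * CtaTileN];  // stage 0 and stage 1 for B

  int write_stage = 0;
  load_global_to_smem(smem_a[write_stage], smem_b[write_stage], A, B, /*cta_k=*/0);
  __syncthreads();

  for (int cta_k = 0; cta_k < GemmK; cta_k += CtaTileK) {
    int read_stage = write_stage;
    write_stage ^= 1;

    // Issue the next stage's global memory loads into the other buffer ...
    if (cta_k + CtaTileK < GemmK) {
      load_global_to_smem(smem_a[write_stage], smem_b[write_stage], A, B, cta_k + CtaTileK);
    }

    // ... while computing on the current buffer.
    warp_level_gemm(smem_a[read_stage], smem_b[read_stage]);

    __syncthreads();  // ensure the next stage's data is visible before reuse
  }
}
```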

The following diagram illustrates the efficient, pipelined mainloop body used in CUTLASS GEMMs.



### Parallelized Reductions

**Split K - reduction across threadblocks**

CUTLASS implements parallel reductions across threadblocks by partitioning the GEMM _K_ dimension
and launching an additional set of threadblocks for each partition. Consequently, we refer to
this strategy within CUTLASS as "parallel reduction splitK." The "parallel reduction splitK" strategy
requires the execution of two kernels: partitionedK GEMM and batched reduction.

PartitionedK GEMM resembles one flavor of batched strided GEMM. Instead of requiring users
to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the
number of partitions that will be applied along the K dimension for operands A and B. For example,
parameters of m=128, n=128, k=4096 and partition=16 will result in 16 batched strided GEMMs
with each batch of m=128, n=128, k=256. PartitionedK also allows scenarios where k is not divisible
by the partition count.

For example, parameters of m=128, n=128, k=4096 and partition=20
will result in 20 batched strided GEMMs.
The first 19 batches will have m=128, n=128, and k=4096/20=204 (integer division),
and the last batch will have m=128, n=128, and k=220.

The batched reduction kernel takes as input the output (C) of partitionedK GEMM,
and performs a reduction along the K-dimension.
Users must manage workspace memory to store this intermediate result.
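
As an illustration, the sketch below drives this two-kernel strategy through the device-level
`GemmSplitKParallel` operator, which launches the partitionedK GEMM followed by the batched
reduction. The element types, layouts, and the choice of 16 slices are assumptions, and the
exact `Arguments` signature may differ across CUTLASS versions.

```c++
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/device_memory.h"

// Sketch: single-precision parallel split-K GEMM partitioned into 16 K-slices.
using GemmSplitK = cutlass::gemm::device::GemmSplitKParallel<
    float, cutlass::layout::ColumnMajor,   // A
    float, cutlass::layout::ColumnMajor,   // B
    float, cutlass::layout::ColumnMajor>;  // C

cutlass::Status run_split_k(float const *A, int lda, float const *B, int ldb,
                            float *C, int ldc, int m, int n, int k) {
  int split_k_slices = 16;  // number of partitions along K (assumed)

  GemmSplitK::Arguments args({m, n, k},
                             {A, lda}, {B, ldb},
                             {C, ldc}, {C, ldc},
                             {1.0f, 0.0f},  // alpha, beta
                             split_k_slices);

  // The intermediate partial products live in user-managed workspace memory.
  size_t workspace_size = GemmSplitK::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  GemmSplitK gemm_op;
  return gemm_op(args, workspace.get());
}
```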
**Sliced K - reduction across warps**

Similar to the split-k scenario, sliced-k aims at improving the efficiency of kernels
with smaller M and N dimensions, but a large K dimension.
At the thread-block level, the parameters CtaTileN and CtaTileM expose parallelism
by partitioning the work among warps.
Larger warpTiles expose better instruction-level parallelism (ILP) and reuse,
but also limit the number of warps running per threadblock, which reduces efficiency.

In order to improve efficiency in such scenarios, partitioning the warpTiles also along ctaTileK
helps use the hardware more efficiently by allowing more warps to run concurrently in a CTA.
Sliced-k kernels break down a threadblock's computation among participating warps
not just along the CtaTileN and CtaTileM dimensions, but also along the CtaTileK dimension.
Thus, sliced-k entails a small cost in the form of a reduction
which has to happen at the end among the participating warps.
This is because each warp computes using only a "slice" of CtaTileK,
so each warp only has a partial sum before the reduction.

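A minimal sketch of that final cross-warp reduction follows. It is hypothetical rather than
CUTLASS's implementation, and it assumes each of `kSlices` warps has already written its partial
accumulator tile for the same output tile into shared memory.

```c++
// Hypothetical sketch (not CUTLASS source): sum the partial accumulator tiles
// produced by warps that each computed over one slice of CtaTileK.
template <int kSlices, int kTileElems>
__device__ void sliced_k_reduction(
    float const (&partials)[kSlices][kTileElems],  // per-slice partial tiles in shared memory
    float *reduced) {                              // final tile, kTileElems elements
  // Threads of the CTA cooperatively reduce across the slice dimension.
  for (int e = threadIdx.x; e < kTileElems; e += blockDim.x) {
    float sum = 0.0f;
    #pragma unroll
    for (int s = 0; s < kSlices; ++s) {
      sum += partials[s][e];  // combine the partial sums from each K-slice
    }
    reduced[e] = sum;
  }
}
```
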
# Resources