CUTLASS 3.8 Release (#2059)
* CUTLASS 3.8 Release
* update
* Update README.md
* Revert "Update README.md"
This reverts commit b353e36fe8.
* update
* update
---------
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
584
media/docs/blackwell_functionality.md
Normal file
584
media/docs/blackwell_functionality.md
Normal file
@ -0,0 +1,584 @@
|
||||
# Blackwell SM100 GEMMs
|
||||
|
||||
[**TLDR; jump to block scaled GEMM example**](#detailed_blockscale_example)
|
||||
|
||||
Blackwell SM100 introduces `tcgen05.mma` instructions. `tcgen05.mma` instructions support all legacy types (`tfloat32_t`, `half_t`, `bfloat16_t`, `int8_t`, `uint8_t`) and
|
||||
the new 4, 6, and 8-bits floating point datatypes with and without scale factors.
|
||||
This document explains the new `tcgen05.mma` instructions supported by CUTLASS and how one can leverage CUTLASS to create
|
||||
efficient SM100 GEMM kernels targeting these new mma instructions.
|
||||
|
||||
Blackwell SM100 has 7 new `tcgen05.mma` instructions. These instructions are 2x to 4x faster then Hopper Architecture's WGMMA instructions.
|
||||
|
||||
| Ptx Instruction | Throughput | Notes |
|
||||
|----------------------------------------------------------------------------------|----------------------------|-------|
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts |
|
||||
|
||||
For more detailed information see [`tcgen05.mma` PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensorcore-5th-generation-family-instructions).
|
||||
|
||||
## New in Blackwell SM100
|
||||
|
||||
### Block Scaled GEMMs
|
||||
|
||||
Instructions with `kind` modifiers `mxf8f6f4`, `mxf4`, and `nvf4mxf4` perform matrix multiplication operations with scale
|
||||
factors of the form $D = C +( A \times SFA) * (B \times SFB)$. Scale factors are applied to GEMM-K dimension such that
|
||||
every 16 or 32 elements of $A$ and $B$ matrices in K dimension have an associated scale factor. For example, an $M\times K$,
|
||||
$A$ matrix has an associated $M \times \lceil K/32 \rceil$ SFA matrix; and an $N\times K$ $B$, matrix has an associated
|
||||
$N \times \lceil K/32 \rceil$ SFB matrix. For block scaled GEMMs, an entry of output D matrix is
|
||||
$D_{ij} = C_{ij} + \sum_{k} (A_{i,k} \times SFA_{i,k/SV}) \times (B_{j,k}\times SFB_{j,k/SV})$, in index notation, we SV is the scale factor vector size (16 or 32).
|
||||
Further details can be found in
|
||||
[PTX documentation on block scaling](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-block-scaling).
|
||||
|
||||
### Blackwell Narrow Precision Data Types
|
||||
|
||||
Narrow-precision `tcgen05.mma` instructions can operate on several 4, 6, and 8-bit data types. Blackwell MMAs can operate
|
||||
on five different 8-bit floating point values, of which only two (`float_ue8m0_t` and `float_ue4m3_t`) can be used as scale factor data types.
|
||||
There are two 6-bit floating point types and one 4-bit floating point data type.
|
||||
See [PTX documentation for narrow precision data types](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats) for details.
|
||||
|
||||
**Blackwell Narrow Precision Data Types**
|
||||
| Data Type | Exponent Bits | Mantissa Bits | Signed | Bit Size |
|
||||
|-------------------|---------------|---------------|--------|----------|
|
||||
| float_e4m3_t |4 |3 | Yes | 8 |
|
||||
| float_e5m2_t |5 |2 | Yes | 8 |
|
||||
| float_e2m3_t |2 |3 | Yes | 6 |
|
||||
| float_e3m2_t |3 |2 | Yes | 6 |
|
||||
| float_e2m1_t |2 |1 | Yes | 4 |
|
||||
| float_ue8m0_t[^1] |8 |0 | No | 8 |
|
||||
| float_ue4m3_t[^1] |4 |3 | No | 8 |
|
||||
|
||||
[^1]: Only valid as scale factor data types.
|
||||
|
||||
Block scaled MMAs use `mx` and `nv` types which are a pair of float8_t, float6_t, float4_t with 2 of the scale factor data types with a predetermined scale factor vector size. `mx` types follow OCP specification (see [OCP Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)). The following types provided by CUTLASS can be used as inputs to collective builders to generate the block scaled kernels:
|
||||
|
||||
**Blackwell Block Scaled Narrow Precision Data Types**
|
||||
| Mx/Nv Data Type |Scale Factor Type | SF Vector Size | OCP Compliant |
|
||||
|----------------------------|------------------|----------------|---------------|
|
||||
| mx_float8_t\<Any F8type\> |float_ue8m0_t |32 | Yes |
|
||||
| mx_float6_t\<Any F6Type\> |float_ue8m0_t |32 | Yes |
|
||||
| mx_float4_t |float_ue8m0_t |32 | Yes |
|
||||
| nv_float4_t |float_ue4m3_t |16 | No |
|
||||
|
||||
## Layouts, Tensor Alignment Requirements to Target `tcgen05.mma` Instructions
|
||||
|
||||
Tables below list valid data type, and AB layout combinations. Note that the alignment is reported as number of elements. A and B matrix layouts are
|
||||
represented with T and N. T represents row-major layouts, and N represents column-major layouts. For instance, TN is
|
||||
row-major A matrix with column-major B matrix.
|
||||
|
||||
For legacy types (`tf32`, `f16`, `bf16`, `i8` and `u8`) alignment requirements for A and B matrices are the same as in Hopper.
|
||||
All four layouts (TT, NN, NT, TT) are supported for all legacy data types.
|
||||
|
||||
**Table 1: Valid Data Type, Alignment, and Layout Combinations For MMAs with Legacy Types** <a id="legacy_gemm_table" name="legacy_gemm_table"></a>
|
||||
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|
||||
|-------------------------------|------------|------------|----------------|-------------|-------------|-------------------------|-----------|
|
||||
|1 | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | |
|
||||
|2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|
||||
|3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|
||||
|4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
|
||||
|5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
|
||||
|
||||
For narrow precision Mmas, not all A/B type, and A/B layout combinations are supported by every `tcgen05.mma` instructions.
|
||||
Furthermore, tensor copy instructions for subbyte types impose additional alignment requirements while loading narrow-precision
|
||||
tensors from global memory to shared memory
|
||||
(see [PTX doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-tensor-copy-restrictions) for details).
|
||||
|
||||
Below tables list valid layout, and alignment values for each A and B data type combination and their target `tcgen05.mma`
|
||||
instructions supported by CUTLASS.
|
||||
|
||||
**Table 2: Valid Data Type, Alignment, and Layout Combinations For Narrow Precision MMAs Without Block Scaling** <a id="non_bs_gemm_table" name="non_bs_gemm_table"></a>
|
||||
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|
||||
|-------------------------------|----------|----------|----------------|-------------|-------------|-------------------------|-----------|
|
||||
|[1](#nonbs_rows_1_2_3_6) | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[2](#nonbs_rows_1_2_3_6) | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[3](#nonbs_rows_1_2_3_6) | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[4](#nonbs_rows_4_7) | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|
||||
|[5](#nonbs_rows_5_8) | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|
||||
|[6](#nonbs_rows_1_2_3_6) | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[7](#nonbs_rows_4_7) | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|
||||
|[8](#nonbs_rows_5_8) | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|
||||
|[9](#nonbs_rows_9) | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)|
|
||||
|
||||
|
||||
**Table 3: Valid Data Type, Alignment, and Layout Combinations for Block Scaled Narrow Precision MMAs** <a id="bs_gemm_table" name="bs_gemm_table"></a>
|
||||
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test|
|
||||
|-------------------------|-------------|-------------|----------------|-------------|-------------|-------------------------|------|
|
||||
|[1](#bs_rows_1) | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)|
|
||||
|[2](#bs_rows_2) | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)|
|
||||
|[3](#bs_rows_3) | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)|
|
||||
|[4](#bs_rows_4_5_7_8_10) | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)|
|
||||
|[5](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)|
|
||||
|[6](#bs_rows_6_9_11) | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)|
|
||||
|[7](#bs_rows_4_5_7_8_10) | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)|
|
||||
|[8](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)|
|
||||
|[9](#bs_rows_6_9_11) | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)|
|
||||
|[10](#bs_rows_4_5_7_8_10)| mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)|
|
||||
|[11](#bs_rows_6_9_11) | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu.cu)<br>[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)|
|
||||
|
||||
## MMA tile shapes supported
|
||||
|
||||
The alignment restrictions also limit the options for Mma Tile Shapes. Tables below list the supported/valid `MmaTileShape`,
|
||||
Layout, and Dispatch Policy combinations for each row of [Table 1](#legacy_gemm_table), [Table 2](#non_bs_gemm_table), and [Table 3](#bs_gemm_table).
|
||||
|
||||
**Table 4: Valid Tile Shapes and Dispatch Policies for lagacy types (All rows of Table 1)** <a id="legacy_rows" name="legacy_rows"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|------------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1,2,3,6 of Table 2)** <a id="nonbs_rows_1_2_3_6" name="nonbs_rows_1_2_3_6"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 5,8 of Table 2)** <a id="nonbs_rows_5_8" name="nonbs_rows_5_8"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4,7 of Table 2)** <a id="nonbs_rows_4_7" name="nonbs_rows_4_7"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Row 9 of Table 2)** <a id="nonbs_rows_9" name="nonbs_rows_9"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
|
||||
**Table 9: Valid Tile Shapes for nv_float4_t x nv_float4_t (Row 1 of Table 3)** <a id="bs_rows_1" name="bs_rows_1"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|----------------------------------------|
|
||||
| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
|
||||
**Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 2 of Table 3)** <a id="bs_rows_2" name="bs_rows_2"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|----------------------------------------|
|
||||
| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
|
||||
**Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 3 of Table 3)** <a id="bs_rows_3" name="bs_rows_3"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|--------------------------------------------|
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
|
||||
**Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10 of Table 3)** <a id="bs_rows_4_5_7_8_10" name="bs_rows_4_5_7_8_10"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|--------------------------------------------|
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
|
||||
**Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11 of Table 3)** <a id="bs_rows_6_9_11" name="bs_rows_6_9_11"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN| TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|--------------------------------------------|
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
|
||||
## Epilogue config supported
|
||||
|
||||
**Table 14: Epilogue Dispatch Policy** <a id="epi_dispatch" name="epi_dispatch"></a>
|
||||
| 1/2 SM | Epilogue Dispatch Policy |
|
||||
|--------|------------------------------------------|
|
||||
| 1SM | cutlass::epilogue::TmaWarpSpecialized1Sm |
|
||||
| 2SM | cutlass::epilogue::TmaWarpSpecialized2Sm |
|
||||
|
||||
**Table 15: Epilogue PerSmTileShape_MNK** <a id="epi_persmtileshape" name="epi_persmtileshape"></a>
|
||||
| 1/2 SM | MMA tile Shape | PerSmTileShape_MNK |
|
||||
|--------|--------------------------|-------------------------|
|
||||
| 1SM | 64x64xMMA_TileShape_K | 64x64xMMA_TileShape_K |
|
||||
| 1SM | 64x128xMMA_TileShape_K | 64x128xMMA_TileShape_K |
|
||||
| 1SM | 64x192xMMA_TileShape_K | 64x192xMMA_TileShape_K |
|
||||
| 1SM | 64x256xMMA_TileShape_K | 64x256xMMA_TileShape_K |
|
||||
| 1SM | 128x64xMMA_TileShape_K | 128x64xMMA_TileShape_K |
|
||||
| 1SM | 128x128xMMA_TileShape_K | 128x128xMMA_TileShape_K |
|
||||
| 1SM | 128x192xMMA_TileShape_K | 128x192xMMA_TileShape_K |
|
||||
| 1SM | 128x256xMMA_TileShape_K | 128x256xMMA_TileShape_K |
|
||||
| 2SM | 128x64xMMA_TileShape_K | 64x64xMMA_TileShape_K |
|
||||
| 2SM | 128x128xMMA_TileShape_K | 64x128xMMA_TileShape_K |
|
||||
| 2SM | 128x192xMMA_TileShape_K | 64x192xMMA_TileShape_K |
|
||||
| 2SM | 128x256xMMA_TileShape_K | 64x256xMMA_TileShape_K |
|
||||
| 2SM | 256x64xMMA_TileShape_K | 128x64xMMA_TileShape_K |
|
||||
| 2SM | 256x128xMMA_TileShape_K | 128x128xMMA_TileShape_K |
|
||||
| 2SM | 256x192xMMA_TileShape_K | 128x192xMMA_TileShape_K |
|
||||
| 2SM | 256x256xMMA_TileShape_K | 128x256xMMA_TileShape_K |
|
||||
|
||||
MMA_TileShape_K is is generally 4 * MMA-Instruction-K. It depends on the config we defined in MMA tile shapes supported section.
|
||||
|
||||
### Auto Kernel Dispatch Policies
|
||||
|
||||
In addition to direct dispatch policies listed above, the user can also use auto policies for both non-block scaled narrow-precision
|
||||
GEMMs, and block scaled narrow-precision GEMMs.
|
||||
|
||||
CUTLASS will do its best to find the most efficient kernel for given parameters, however, the preferred method for building
|
||||
these kernels is to use direct kernel dispatch policies shown in the above tables.
|
||||
|
||||
* `cutlass::gemm::collective::KernelScheduleAuto`: For a given Mma Tile Size, data type and layout combinations choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) and 1/2 SM `tcgen05.mma`.
|
||||
* `KernelTmaWarpSpecialized1SmBlockScaledSm100`: Use 1 SM `tcgen05.mma` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
|
||||
* `KernelTmaWarpSpecialized2SmBlockScaledSm100`: Use 2 SM `tcgen05.mma` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
|
||||
|
||||
Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueScheduleAuto`.
|
||||
|
||||
## Building a Block Scaled Kernel <a id="detailed_blockscale_example" name="detailed_blockscale_example"></a>
|
||||
|
||||
For non-blockscaled dense GEMM refer to [quick start page](quickstart.md#instantiating-a-blackwell-gemm-kernel). An example dense GEMM can be found:
|
||||
1. [Blackwell FP16 GEMM example](../../examples/70_blackwell_gemm/).
|
||||
|
||||
Narrow precision and block scaled narrow precision kernels can be built using CUTLASS 3.x collective builder interface
|
||||
(as described in [CUTLASS 3.0 GEMM API](gemm_api_3x.md#cutlass-30-gemm-api)). However, special attention needs to be given to
|
||||
A and B matrix layouts, alignment requirements, and dispatch policies to obtain a functionally correct and performant kernel
|
||||
which are listed above.
|
||||
|
||||
Several examples of block scaled kernels can be found in [examples/72_blackwell_narrow_precision_gemm](../../examples/72_blackwell_narrow_precision_gemm/) directory:
|
||||
1. [NVF4 Gemm with block scaling](../../examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
2. [NVF4 Gemm with block scaling and NVF4 output matrix](../../examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
3. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](../../examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
|
||||
Collective builder interface expects the same arguments as any other CUTLASS 3.x kernels as described
|
||||
[here](gemm_api_3x.md#collective-builder-for-collectivemmas) with a small difference for Collective MMA builder interface.
|
||||
As in all Blackwell kernels, the `TileShape_MNK` argument expects the `MmaTileShape_MNK` which is the tile shape needed
|
||||
by 1 or 2 SM `tcgen05.mma` instructions.
|
||||
|
||||
Let's consider building a block scaled GEMM where the A matrix is of type `mx_float4_t` and column-major (N), and the
|
||||
B matrix is of type `mx_float4_t` and row-major (T). We first need to describe the A and B tensors, and find the
|
||||
instruction that can support the selected A and B type and layout pair. Then, we will choose the performance parameters.
|
||||
|
||||
The skeleton C++ code is shown below:
|
||||
|
||||
```cpp
|
||||
///////////////////////////////////////////////////////////
|
||||
// Mainloop Builder Setup
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
///////////////////////////////////////////
|
||||
// 1. Describe A and B tensors
|
||||
///////////////////////////////////////////
|
||||
using ElementA = // TBD
|
||||
constexpr int AlignA = // TBD
|
||||
using GmemLayoutA = // TBD
|
||||
using ElementB = // TBD
|
||||
constexpr int AlignB = // TBD
|
||||
using GmemLayoutB = // TBD
|
||||
|
||||
// Mma's accumulator type
|
||||
using ElementAccumulator = float; // Always float for block scaled tcgen05.mma instructions
|
||||
|
||||
//////////////////////////////////////////
|
||||
// 2. Choose Performance Parameters
|
||||
//////////////////////////////////////////
|
||||
|
||||
// Tile and cluster shapes
|
||||
// Collective MMA takes tile shape of the MMA operation as input
|
||||
using KernelMainloopPolicy = // TBD
|
||||
using MmaTileShape_MNK = // TBD
|
||||
using ClusterShape_MNK = // TBD
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec
|
||||
ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement
|
||||
ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement
|
||||
ElementAccumulator, // Mma instruction accumulator type
|
||||
MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape
|
||||
// Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelMainloopPolicy // Kernel schedule policy.
|
||||
// Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
```
|
||||
|
||||
From the valid type and layout combinations [Table 3](#bs_gemm_table), we see that only **row 3** can support `mx_float4_t`x`mx_float4_t`
|
||||
combination with NT layout. As a result, we need to use the `tcgen05.mma.kind:mxf8f6f4` instruction. Additionally, in order
|
||||
to use `tcgen05.mma.kind:mxf8f6f4`, we see that A and B tensors both should be 128-element aligned.
|
||||
Thus, we can describe A and B tensors as follows:
|
||||
|
||||
```cpp
|
||||
///////////////////////////////////////////////////////////
|
||||
// Mainloop Builder Setup
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
///////////////////////////////////////////
|
||||
// 1. Describe A and B tensors
|
||||
///////////////////////////////////////////
|
||||
using ElementA = mx_float4_t;
|
||||
constexpr int AlignA = 128;
|
||||
using GmemLayoutA = cutlass::layout::ColumnMajor;
|
||||
using ElementB = mx_float4_t;
|
||||
constexpr int AlignB = 128;
|
||||
using GmemLayoutB = cutlass::layout::RowMajor;
|
||||
```
|
||||
Next, we need to choose the performance parameters such as `MmaTileShape_MNK`, `KernelMainloopPolicy`,
|
||||
and `ClusterShape_MNK`.
|
||||
|
||||
`MmaTileShape_MNK` supported for `mx_float4_t`x`mx_float4_t` with `mxf8f6f4` are listed in [Table 11](#bs_rows_3).
|
||||
For NT layout, we see that 3 `MmaTileShape_MNK` are supported: `128x128x128`, and `128x256x128` with 1SM instruction;
|
||||
and `256x256x128` with 2SM instruction. Let's say, we expect to get the best performance with `256x256x128` MMA tile shape
|
||||
for our GEMM problem. Then, we need to set the `KernelMainloopPolicy` to `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100`.
|
||||
Now, we need to choose the `ClusterShape_MNK`. Since we have selected a 2SM mma instruction, `ClusterShape_MNK` should be
|
||||
compatible and its first mode should be a multiple of 2. `ClusterShape_MNK = cute::Shape<_2, [_1|_2|_4], _1>` or
|
||||
`ClusterShape_MNK = cute::Shape<_4, [_1|_2|_4], _1>` would be valid options. Let's choose `cute::Shape<_4,_4,_1>`.
|
||||
Our performance parameters looks like below:
|
||||
|
||||
```cpp
|
||||
//////////////////////////////////////////
|
||||
// 2. Choose Performance Parameters
|
||||
//////////////////////////////////////////
|
||||
|
||||
// Tile and cluster shapes
|
||||
// Collective MMA takes tile shape of the MMA operation as input
|
||||
using KernelMainloopPolicy = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100;
|
||||
using MmaTileShape_MNK = cute::Shape<_256,_256,_128>;
|
||||
using ClusterShape_MNK = cute::Shape<_4,_4,_1>;
|
||||
```
|
||||
|
||||
After we config the main-loop, let's setup the epilogue.
|
||||
A normal epilogue looks like below, we need to specify the output layout, datatype, alignment and PerSmTileShape_MNK, and let others to be default/auto.
|
||||
|
||||
PerSmTileShape_MNK should be deduced from the mainloop setup. For example, in above mainloop setup, the MmaTileShape_MNK is
|
||||
256x256x128 and the KernelMainloopPolicy is 2sm policy.
|
||||
It means each CTA is doing (256 / 2sm) x 256 x 128 output, so the PerSmTileShape_MNK is 128x256x128. The possible PerSmTileShape_MNK
|
||||
is listed in [Table 15](#epi_persmtileshape)
|
||||
|
||||
The epilogue scheduling policy is configurable, and it is common to set `cutlass::epilogue::TmaWarpSpecialized2Sm`
|
||||
to allow the epilogue builder to automatically select the appropriate policy. However, it can also be explicitly defined to
|
||||
use other policies based on the 1sm or 2sm MMA instruction. The available policies are listed in [Table 14](#epi_dispatch).
|
||||
|
||||
```cpp
|
||||
// Describe C and D tensors
|
||||
using ElementC = cutlass::half_t;
|
||||
constexpr int AlignC = 8;
|
||||
using GmemLayoutC = cutlass::layout::RowMajor;
|
||||
using ElementD = cutlass::float_e2m1_t;
|
||||
constexpr int AlignD = 32;
|
||||
using GmemLayoutD = cutlass::layout::RowMajor;
|
||||
// Mma's accumulator type
|
||||
using ElementAccumulator = float;
|
||||
// Epilogue computation's precision type
|
||||
using ElementCompute = float;
|
||||
// Cluster size for multicast
|
||||
using ClusterShape_MNK = Shape<_4,_4,_1>;
|
||||
// Collective Epilogue takes the output tile shape for 1 CTA
|
||||
using PerSmTileShape_MNK = Shape<_128,_256,_128>;
|
||||
|
||||
//
|
||||
// Construct CollectiveEpilogue
|
||||
//
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec
|
||||
PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape
|
||||
ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue
|
||||
ElementC, GmemLayoutC, AlignC, // C tensor description
|
||||
ElementD, GmemLayoutD, AlignD, // D tensor description
|
||||
cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
|
||||
```
|
||||
|
||||
If we want to let the epilogue generate mxf4/nvf4/mxf6/mxf8 (i.e. elements + block-scalefactor), we need to setup the epilogue fusion into the builder.
|
||||
First, we need to choose a SFDVectorSize indicates how many elements sharing the same block-scalefactor.
|
||||
Then, we need to choose ElementSFD and GmemLayoutSFD which indicates the output datatype and which output-dim is used to generate the block-scalefactor.
|
||||
Typically, GmemLayoutSFD would be same as the GmemLayoutD.
|
||||
|
||||
```cpp
|
||||
//
|
||||
// Construct FusionOperation
|
||||
//
|
||||
constexpr int SFDVectorSize = 16;
|
||||
// Define the fusion operation applied during epilogue
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
SFDVectorSize,
|
||||
ElementD, ElementCompute,
|
||||
ElementSFD, GmemLayoutSFD,
|
||||
ElementC
|
||||
>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec
|
||||
PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape
|
||||
ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue
|
||||
ElementC, GmemLayoutC, AlignC, // C tensor description
|
||||
ElementD, GmemLayoutD, AlignD, // D tensor description
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
FusionOperation // <================================== Pass the fusion config into epilogue builder.
|
||||
>::CollectiveOp;
|
||||
```
|
||||
|
||||
Above example made a gentle introduction to using the fusion operations in the epilogue. For more detailed example, see
|
||||
[Blackwell GEMM with collective builder](../../examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu)
|
||||
|
||||
Note that we have first discussed the CollectiveMainloop, then the CollectiveEpilogue for clarity.
|
||||
However, the CollectiveMainloop needs to know the SMEM utilization of the epilogue. Therefore, it needs to be setup before the CollectiveMainloop. See [examples/72_blackwell_narrow_precision_gemm](../../examples/72_blackwell_narrow_precision_gemm/) directory for full kernel and run setup.
|
||||
|
||||
### Scale Factor Layouts
|
||||
|
||||
The scale factor layout consists of a 512B basic-block structure, as illustrated in the diagram below. Each block contains 128 M/N dimension and 4 scale factors (SF) along the K dimension.
|
||||
The byte order of the basic storage chunk is row-major, meaning that M0SF0 to M0SF3, M32SF0 to M32SF3, M64SF0 to M64SF3, and M96SF0 to M96SF3 are stored consecutively in GMEM.
|
||||
|
||||
[](../images/M128xK4_scalefactor_gmem.png)
|
||||
<p align="center">
|
||||
<img src="../images/M128xK4_scalefactor_gmem.png" alt="/M128xK4_scalefactor_gmem.png"/>
|
||||
</p>
|
||||
|
||||
If the scale factor tensor exceeds M128xSF4, it indicates that there are multiple basic blocks along both the M and SFK dimensions. The arrangement of these basic blocks follows a K-major order. Here is a diagram illustrating the scenario where M equals 512 and the SFK is 16.
|
||||
|
||||
[](../images/narrow_precison_multiple_block_sf_layout.png)
|
||||
<p align="center">
|
||||
<img src="../images/narrow_precison_multiple_block_sf_layout.png" alt="/narrow_precison_multiple_block_sf_layout.png"/>
|
||||
</p>
|
||||
|
||||
The creation of scale factor tensors' layouts are tedious. CUTLASS provides `Sm100BlockScaledConfig` to create these layouts easily
|
||||
(See [sm100_blockscaled_layout.hpp](cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp)).
|
||||
The interface to create SFA and SFB tensor layouts is as follows:
|
||||
|
||||
```cpp
|
||||
auto problem_shape = make_shape(M, N, K, L);
|
||||
using SfConfig = Sm100BlockScaledConfig<SFVecSize>;
|
||||
|
||||
// SFA shape: ((32,4), ceil(M/128)), ((SFVecSize,4), ceil(K/4), L)
|
||||
auto layout_sfa = SfConfig::tile_atom_to_shape_SFA(problem_shape);
|
||||
// SFB shape: ((32,4), ceil(N/128)), ((SFVecSize,4), ceil(K/4), L)
|
||||
auto layout_sfb = SfConfig::tile_atom_to_shape_SFB(problem_shape);
|
||||
|
||||
auto tensor_sfa = make_tensor(aptr, layout_sfa);
|
||||
auto tensor_sfb = make_tensor(bptr, layout_sfb);
|
||||
// Access SF for for element m,k of A tensor
|
||||
auto val_a_mk = tensor_sfa(make_coord(m,k,0));
|
||||
```
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
@ -2,19 +2,24 @@
|
||||
|
||||
# Dependent kernel launches
|
||||
|
||||
The Hopper architecture supports a new feature through which two kernels in the same stream can
|
||||
The Hopper and Blackwell architectures supports a new feature through which two kernels in the same stream can
|
||||
overlap their execution, named
|
||||
[Programmatic Dependent Launch (PDL)](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization).
|
||||
This allows kernels with conflict in global memory to programmatically and safely overlap portions
|
||||
of their execution. Primary kernel can signal it is about to finish execution, and the next kernel can
|
||||
optionally wait on the previous kernel to finish flushing its memory.
|
||||
of their execution. Primary kernel can signal it is about to finish execution, and the next kernel is expected to
|
||||
programatically wait on the previous kernel to finish flushing its memory.
|
||||
|
||||
We enable PDL by setting a flag through the extended CUDA launch APIs. All CUTLASS kernels with PDL support
|
||||
will wait on the prior kernel to flush its output to memory and signal the next kernel to start. This means
|
||||
they can safely be dropped in with any other set of kernels using PDL as long as they also adhear to waiting on
|
||||
the prior to flush its memory as well.
|
||||
|
||||
For more information, we refer you to the [PDL section in the CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization).
|
||||
|
||||
## Using dependent launch in CUTLASS
|
||||
|
||||
When building CUTLASS, you can use the `CUTLASS_ENABLE_GDC_FOR_SM90` macro to
|
||||
enable PDL-related instructions in Hopper kernels:
|
||||
When building CUTLASS, you can use the `CUTLASS_ENABLE_GDC_FOR_SM90` and `CUTLASS_ENABLE_GDC_FOR_SM100` macro
|
||||
respectively to enable PDL-related instructions:
|
||||
|
||||
```
|
||||
cmake . -DCUTLASS_ENABLE_GDC_FOR_SM90=1
|
||||
@ -30,3 +35,10 @@ gemm.run(
|
||||
/* launch_with_pdl = */ true
|
||||
);_
|
||||
```
|
||||
## Model-Aware Optimizations with PDL
|
||||
|
||||
In [example 63](../../examples/63_hopper_gemm_with_weight_prefetch/README.md), we use PDL to explicitly optimize for
|
||||
performance of kernels where we know that one of the input matricies (our weights) will not be produced by a prior
|
||||
kernel. In that case, we only need to wait on the prior kernels memory flush in order to load the other input matrix
|
||||
(our activations). During our prologue, we can prefetch our weights to improve performance for memory bandwidth-bound
|
||||
problem sizes. For more informations we refer the reader to [the example](../../examples/63_hopper_gemm_with_weight_prefetch/README.md).
|
||||
|
||||
@ -219,7 +219,11 @@ which has to happen at the end among the participating warps.
|
||||
This is because each warp computes using only a "slice" of CtaTileK,
|
||||
so each warp only has a partial sum before the reduction.
|
||||
|
||||
### Warp Specialization
|
||||
### Hopper Warp Specialization
|
||||
|
||||
Note: the following section on warp-specialization contains details that are specific
|
||||
to the Hopper kernel design. Blackwell SM100 kernels have a substantially different warp-specialization structure,
|
||||
however, the concept of separating out producer and consumer agents still applies.
|
||||
|
||||
Starting with Hopper, CUTLASS 3.0 incorporates the concept of [Warp Specialization](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#spatial-partitioning-also-known-as-warp-specialization)
|
||||
as part of the kernel design. A thread block is partitioned into two sets of warps, [*producer* warp group](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [*consumer* warp group](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp). The *producer* warp group loads data from global memory into shared memory buffers using the new [Tensor Memory Accelerator (TMA)](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/).
|
||||
|
||||
@ -20,6 +20,20 @@ CUTLASS defines classes for the following numeric data types.
|
||||
* `tfloat32_t`: Tensor Float 32 data type (exponent: 8b, mantissa: 10b; literal suffix `_tf32`)
|
||||
* `int4_t`, `uint4_t`: 4b signed and unsigned integer (literal suffx `_s4`, `_u4`)
|
||||
* `bin1_t`: 1b binary numeric type (literal suffix `_b1`)
|
||||
* `float_e5m2_t`: 8bits signed float (exponent: 5 bits, mantissa: 2 bits)
|
||||
* `float_e4m3_t`: 8bits signed float (exponent: 4 bits, mantissa: 3 bits)
|
||||
* `float_ue4m3_t`: 8bits unsigned float (exponent: 4 bits, mantissa: 3 bits)
|
||||
* `float_ue8m0_t`: 8bits unsigned float (exponent: 8 bits, mantissa: 0 bits)
|
||||
* `float_e3m2_t`: 6bits signed float (exponent: 3 bits, mantissa: 2 bits)
|
||||
* `float_e2m3_t`: 6bits signed float (exponent: 2 bits, mantissa: 3 bits)
|
||||
* `float_e2m1_t`: 4bits signed float (exponent: 2 bits, mantissa: 1 bits)
|
||||
* `type_erased_dynamic_float8_t`: Type agnostic 8 bits signed float allowing the user to provide a specific datatype as runtime argument.
|
||||
* `type_erased_dynamic_float6_t`: Type agnostic 6 bits signed float allowing the user to provide a specific datatype as runtime argument.
|
||||
* `type_erased_dynamic_float4_t`: Type agnostic 4 bits signed float allowing the user to provide a specific datatype as runtime argument.
|
||||
* `mx_float8_t<float_e5m2_t>` or `mx_float8_t<float_e4m3_t>` : Block scaled data type with fp8 element type and float_ue8m0_t scale factor and vector size of 32.
|
||||
* `mx_float6_t<float_e3m2_t>` or `mx_float6_t<float_e2m3_t>` : Block scaled data type with fp6 element type and float_ue8m0_t scale factor and vector size of 32.
|
||||
* `mx_float6_t<float_e2m1_t>` : Block scaled data type with signed e2m1 element type and float_ue8m0_t scale factor and vector size of 32.
|
||||
* `nv_float4_t<float_e2m1_t>` : Block scaled data type with signed e2m1 element type and float_ue8m0_t scale factor and vector size of 16.
|
||||
* `complex<T>`: defines complex-valued data type based on the supplied real-valued numeric type
|
||||
|
||||
Numeric types in CUTLASS may be used in both host and device code and are intended to function
|
||||
|
||||
@ -115,6 +115,10 @@ usage:
|
||||
("s1688" and "nt") or ("s844" and "tn" and "align8") in their
|
||||
operation name using --kernels="s1688*nt, s884*tn*align8"
|
||||
|
||||
--kernels-file=<path> Same behavior as `kernels`, but kernel names are specified in a file with
|
||||
one kernel name on each line. Set of profiled kernels is the union of kernels
|
||||
specified here and those specified in `kernels`.
|
||||
|
||||
--ignore-kernels=<string_list> Excludes kernels whose names match anything in this list.
|
||||
|
||||
Device:
|
||||
@ -284,6 +288,8 @@ GEMM
|
||||
[int] --max_cc,--maximum-compute-capability Maximum device compute capability
|
||||
[enum] --raster_order={heuristic|H|along_m|M|along_n|N} If supported by kernel, sets the tile raster direction
|
||||
[int] --swizzle_size={1,2,4,8} If supported by kernel, sets the 2D tile swizzle extent (In Hopper, other values will be rounded down to the nearest supported value)
|
||||
[int] --use_pdl,--use-pdl Use PDL (true, false)
|
||||
|
||||
Examples:
|
||||
|
||||
Profile a particular problem size:
|
||||
@ -323,6 +329,8 @@ Profile when execution is performed on device 0 and the C tensor is located on a
|
||||
|
||||
The format of tensor argument is followed by `<type>:<layout>`. The type could be `f32` as 32-bit floating point, `s8` as 8-bit signed integer, etc. The available types can be referred to the `NumericTypeID_enumerants` in [util.cu](tools/library/src/util.cu). The layout could be `row` or `column`.
|
||||
|
||||
CUTLASS 3.x kernels for Hopper and Blackwell also support a new feature called programatic dependent launch (PDL). This can be enabled with `--use-pdl`, and can overlap the epilogue of the prior kernel with the prologue of the next kernel. This can effectively hide kernel prologues. Using PDL can improve performance for back to back GEMMs. See [dependent kernel launch](dependent_kernel_launch.md) for more information.
|
||||
|
||||
## Example CUDA Core GEMM Operation
|
||||
|
||||
Example command line for profiling SGEMM kernels is as follows:
|
||||
|
||||
@ -24,7 +24,8 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
|
||||
|
||||
$ mkdir build && cd build
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=100a # compiles for NVIDIA Blackwell SM100 GPU architecture
|
||||
```
|
||||
|
||||
If your goal is strictly to build only the CUTLASS Profiler and to minimize compilation time, we suggest
|
||||
@ -653,6 +654,105 @@ targeting NVIDIA Ampere, Turing, and Volta Tensor Core operations
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=tensorop*s*wgrad_optimized_f16
|
||||
```
|
||||
|
||||
## Instantiating a Blackwell SM100 GEMM kernel
|
||||
|
||||
Blackwell SM100 kernels are instantiated very similarly to Hopper kernels. Let us start with an
|
||||
[FP8 GEMM without blockscaling](../../test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu)
|
||||
as an example.
|
||||
|
||||
The kernel starts with setting up datatypes and cluster shapes.
|
||||
```c++
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::float_e4m3_t;
|
||||
using ElementD = cutlass::float_e4m3_t;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ClusterTileShape = cute::Shape<_128,_64,Int<128 / sizeof(ElementA)>>;
|
||||
using ClusterShape = Shape<_1,_1,_1>;
|
||||
using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{}));
|
||||
using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{}));
|
||||
using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{}));
|
||||
```
|
||||
|
||||
The epilogue needs to be instantiated first as the mainloop collective builder takes the shared memory budget of epilogue in the template parameter list. The 3.x epilogue collective builder API has not changed
|
||||
for Blackwell, so the epilogue fusion is built in a same way as an SM90 epilogue.
|
||||
|
||||
```c++
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
ElementBias
|
||||
>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
OutputCtaShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementD, LayoutC, 16 / sizeof(ElementD),
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
```
|
||||
|
||||
One can refer to our Sm100 unit tests as examples of how to correctly
|
||||
choose mainloop schedules. All of our dispatch policies can be found in [dispatch_policy.hpp](../../include/cutlass/gemm/dispatch_policy.hpp)
|
||||
and more comprehensive Blackwell specific documentation for valid
|
||||
dispatch policies can be in [blackwell_functionality.md](./blackwell_functionality.md).
|
||||
|
||||
```c++
|
||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, 16 / sizeof(ElementA),
|
||||
ElementB, LayoutB, 16 / sizeof(ElementB),
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
```
|
||||
|
||||
It is worth noting that the mainloop builder takes `MmaTileShape` while the epilogue builder takes `OutputCtaShape`.
|
||||
|
||||
Instantiating a blockscaled GEMM kernel is slightly different. Referring to an [MXFP8 GEMM](./../../test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu) sample unit test, it takes a different tensor operation class:
|
||||
|
||||
```c++
|
||||
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
|
||||
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
|
||||
```
|
||||
|
||||
are needed in the mainloop builder:
|
||||
|
||||
```c++
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, GmemLayoutA, 16,
|
||||
ElementB, GmemLayoutB, 16,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
```
|
||||
|
||||
We encourage a user to refer to Sm100 unit tests and the generated profiler-based kernels as more comprehensive samples.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
|
||||
BIN
media/images/M128xK4_scalefactor_gmem.png
Normal file
BIN
media/images/M128xK4_scalefactor_gmem.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 219 KiB |
1
media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg
Executable file
1
media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg
Executable file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 9.8 KiB |
BIN
media/images/narrow_precison_multiple_block_sf_layout.png
Normal file
BIN
media/images/narrow_precison_multiple_block_sf_layout.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 37 KiB |
Reference in New Issue
Block a user